worker.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. from typing import Any, Dict, List, Optional, Set, Tuple
  5. import torch
  6. import torch.distributed
  7. from loguru import logger
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  9. LoRAConfig, ModelConfig, ParallelConfig,
  10. SchedulerConfig, VisionLanguageConfig)
  11. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  12. from aphrodite.distributed import (broadcast_tensor_dict,
  13. ensure_model_parallel_initialized,
  14. get_tensor_model_parallel_cpu_group,
  15. init_distributed_environment)
  16. from aphrodite.distributed.device_communicators import pynccl_utils
  17. from aphrodite.distributed.device_communicators.custom_all_reduce import \
  18. init_custom_ar
  19. from aphrodite.lora.request import LoRARequest
  20. from aphrodite.modeling import set_random_seed
  21. from aphrodite.task_handler.cache_engine import CacheEngine
  22. from aphrodite.task_handler.model_runner import ModelRunner
  23. from aphrodite.task_handler.worker_base import WorkerBase
  24. class Worker(WorkerBase):
  25. """A worker class that executes (a partition of) the model on a GPU.
  26. Each worker is associated with a single GPU. The worker is responsible for
  27. maintaining the KV cache and executing the model on the GPU. In case of
  28. distributed inference, each worker is assigned a partition of the model.
  29. """
  30. def __init__(
  31. self,
  32. model_config: ModelConfig,
  33. parallel_config: ParallelConfig,
  34. scheduler_config: SchedulerConfig,
  35. device_config: DeviceConfig,
  36. cache_config: CacheConfig,
  37. load_config: LoadConfig,
  38. local_rank: int,
  39. rank: int,
  40. distributed_init_method: str,
  41. lora_config: Optional[LoRAConfig] = None,
  42. vision_language_config: Optional[VisionLanguageConfig] = None,
  43. is_driver_worker: bool = False,
  44. ) -> None:
  45. self.model_config = model_config
  46. self.parallel_config = parallel_config
  47. self.scheduler_config = scheduler_config
  48. self.device_config = device_config
  49. self.cache_config = cache_config
  50. self.local_rank = local_rank
  51. self.rank = rank
  52. self.distributed_init_method = distributed_init_method
  53. self.lora_config = lora_config
  54. self.load_config = load_config
  55. self.is_driver_worker = is_driver_worker
  56. if self.is_driver_worker:
  57. assert self.rank == 0, "The driver worker must have rank 0."
  58. if self.model_config.trust_remote_code:
  59. # note: lazy import to avoid importing torch before initializing
  60. from aphrodite.common.utils import init_cached_hf_modules
  61. init_cached_hf_modules()
  62. self.vision_language_config = vision_language_config
  63. if self.vision_language_config:
  64. assert not self.lora_config, (
  65. "To be tested: vision language model with LoRA settings.")
  66. self.model_runner = ModelRunner(
  67. model_config,
  68. parallel_config,
  69. scheduler_config,
  70. device_config,
  71. cache_config,
  72. load_config=load_config,
  73. lora_config=self.lora_config,
  74. kv_cache_dtype=self.cache_config.cache_dtype,
  75. is_driver_worker=is_driver_worker,
  76. vision_language_config=vision_language_config,
  77. )
  78. # Uninitialized cache engine. Will be initialized by
  79. # initialize_cache.
  80. self.cache_engine: CacheEngine
  81. self.gpu_cache: List[torch.Tensor]
  82. def init_device(self) -> None:
  83. if self.device_config.device.type == "cuda":
  84. # torch.distributed.all_reduce does not free the input tensor until
  85. # the synchronization point. This causes the memory usage to grow
  86. # as the number of all_reduce calls increases. This env var disables
  87. # this behavior.
  88. # Related issue:
  89. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  90. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  91. # This env var set by Ray causes exceptions with graph building.
  92. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  93. self.device = torch.device(f"cuda:{self.local_rank}")
  94. torch.cuda.set_device(self.device)
  95. _check_if_gpu_supports_dtype(self.model_config.dtype)
  96. torch.cuda.empty_cache()
  97. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  98. else:
  99. raise RuntimeError(
  100. f"Not support device type: {self.device_config.device}")
  101. # Initialize the distributed environment.
  102. init_worker_distributed_environment(self.parallel_config, self.rank,
  103. self.distributed_init_method,
  104. self.local_rank)
  105. # Set random seed.
  106. set_random_seed(self.model_config.seed)
  107. def load_model(self):
  108. self.model_runner.load_model()
  109. @torch.inference_mode()
  110. def determine_num_available_blocks(self) -> Tuple[int, int]:
  111. """Profiles the peak memory usage of the model to determine how many
  112. KV blocks may be allocated without OOMs.
  113. The engine will first conduct a profiling of the existing memory usage.
  114. Then, it calculate the maximum possible number of GPU and CPU blocks
  115. that can be allocated with the remaining free memory.
  116. .. tip::
  117. You may limit the usage of GPU memory
  118. by adjusting the `gpu_memory_utilization` parameter.
  119. """
  120. # Profile the memory usage of the model and get the maximum number of
  121. # cache blocks that can be allocated with the remaining free memory.
  122. torch.cuda.empty_cache()
  123. # Execute a forward pass with dummy inputs to profile the memory usage
  124. # of the model.
  125. self.model_runner.profile_run()
  126. # Calculate the number of blocks that can be allocated with the
  127. # profiled peak memory.
  128. torch.cuda.synchronize()
  129. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  130. # NOTE: Here we assume that the other processes using the same
  131. # GPU did not change their memory usage during the profiling.
  132. peak_memory = self.init_gpu_memory - free_gpu_memory
  133. assert peak_memory > 0, (
  134. "Error in memory profiling. This happens when the GPU memory was "
  135. "not properly cleaned up before initializing Aphrodite.")
  136. cache_block_size = self.get_cache_block_size_bytes()
  137. num_gpu_blocks = int(
  138. (total_gpu_memory * self.cache_config.gpu_memory_utilization -
  139. peak_memory) // cache_block_size)
  140. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  141. cache_block_size)
  142. num_gpu_blocks = max(num_gpu_blocks, 0)
  143. num_cpu_blocks = max(num_cpu_blocks, 0)
  144. if self.model_runner.lora_manager:
  145. self.model_runner.remove_all_loras()
  146. gc.collect()
  147. torch.cuda.empty_cache()
  148. return num_gpu_blocks, num_cpu_blocks
  149. def initialize_cache(self, num_gpu_blocks: int,
  150. num_cpu_blocks: int) -> None:
  151. """Allocate GPU and CPU KV cache with the specified number of blocks.
  152. This also warms up the model, which may record CUDA graphs.
  153. """
  154. raise_if_cache_size_invalid(num_gpu_blocks,
  155. self.cache_config.block_size,
  156. self.model_config.max_model_len)
  157. self.cache_config.num_gpu_blocks = num_gpu_blocks
  158. self.cache_config.num_cpu_blocks = num_cpu_blocks
  159. self._init_cache_engine()
  160. self._warm_up_model()
  161. def _init_cache_engine(self):
  162. assert self.cache_config.num_gpu_blocks is not None
  163. self.cache_engine = CacheEngine(self.cache_config, self.model_config,
  164. self.parallel_config)
  165. self.gpu_cache = self.cache_engine.gpu_cache
  166. def _warm_up_model(self) -> None:
  167. if not self.model_config.enforce_eager:
  168. self.model_runner.capture_model(self.gpu_cache)
  169. # Reset the seed to ensure that the random state is not affected by
  170. # the model initialization and profiling.
  171. set_random_seed(self.model_config.seed)
  172. def cache_swap(
  173. self,
  174. blocks_to_swap_in: torch.Tensor,
  175. blocks_to_swap_out: torch.Tensor,
  176. blocks_to_copy: torch.Tensor,
  177. ) -> None:
  178. # Issue cache operations.
  179. if blocks_to_swap_in.numel() > 0:
  180. self.cache_engine.swap_in(blocks_to_swap_in)
  181. if blocks_to_swap_out.numel() > 0:
  182. self.cache_engine.swap_out(blocks_to_swap_out)
  183. if blocks_to_copy.numel() > 0:
  184. self.cache_engine.copy(blocks_to_copy)
  185. @torch.inference_mode()
  186. def execute_model(
  187. self,
  188. execute_model_req: Optional[ExecuteModelRequest] = None
  189. ) -> List[SamplerOutput]:
  190. if execute_model_req is None:
  191. seq_group_metadata_list = None
  192. else:
  193. seq_group_metadata_list = execute_model_req.seq_group_metadata_list
  194. blocks_to_swap_in: torch.Tensor
  195. blocks_to_swap_out: torch.Tensor
  196. blocks_to_copy: torch.Tensor
  197. if self.is_driver_worker:
  198. assert seq_group_metadata_list is not None
  199. assert execute_model_req is not None
  200. num_seq_groups = len(seq_group_metadata_list)
  201. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
  202. # they contain parameters to launch cudamemcpyasync.
  203. blocks_to_swap_in = torch.tensor(
  204. execute_model_req.blocks_to_swap_in,
  205. device="cpu",
  206. dtype=torch.int64).view(-1, 2)
  207. blocks_to_swap_out = torch.tensor(
  208. execute_model_req.blocks_to_swap_out,
  209. device="cpu",
  210. dtype=torch.int64).view(-1, 2)
  211. # `blocks_to_copy` is a gpu tensor. The src and tgt of
  212. # blocks to copy are in the same device, and `blocks_to_copy`
  213. # can be used directly within cuda kernels.
  214. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  215. device=self.device,
  216. dtype=torch.int64).view(-1, 2)
  217. data: Dict[str, Any] = {
  218. "num_seq_groups": num_seq_groups,
  219. "blocks_to_swap_in": blocks_to_swap_in,
  220. "blocks_to_swap_out": blocks_to_swap_out,
  221. "blocks_to_copy": blocks_to_copy,
  222. }
  223. broadcast_tensor_dict(data, src=0)
  224. else:
  225. data = broadcast_tensor_dict(src=0)
  226. num_seq_groups = data["num_seq_groups"]
  227. blocks_to_swap_in = data["blocks_to_swap_in"]
  228. blocks_to_swap_out = data["blocks_to_swap_out"]
  229. blocks_to_copy = data["blocks_to_copy"]
  230. self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
  231. # If there is no input, we don't need to execute the model.
  232. if num_seq_groups == 0:
  233. return []
  234. output = self.model_runner.execute_model(seq_group_metadata_list,
  235. self.gpu_cache)
  236. # Worker only supports single-step execution. Wrap the output in a list
  237. # to conform to interface.
  238. return [output]
  239. def add_lora(self, lora_request: LoRARequest) -> bool:
  240. return self.model_runner.add_lora(lora_request)
  241. def remove_lora(self, lora_id: int) -> bool:
  242. return self.model_runner.remove_lora(lora_id)
  243. def list_loras(self) -> Set[int]:
  244. return self.model_runner.list_loras()
  245. @property
  246. def max_model_len(self) -> int:
  247. return self.model_config.max_model_len
  248. @property
  249. def vocab_size(self) -> int:
  250. return self.model_runner.vocab_size
  251. def get_cache_block_size_bytes(self) -> int:
  252. """Get the size of the KV cache block size in bytes.
  253. """
  254. return CacheEngine.get_cache_block_size(self.cache_config,
  255. self.model_config,
  256. self.parallel_config)
  257. def init_worker_distributed_environment(
  258. parallel_config: ParallelConfig,
  259. rank: int,
  260. distributed_init_method: Optional[str] = None,
  261. local_rank: int = -1,
  262. ) -> None:
  263. """Initialize the distributed environment."""
  264. init_distributed_environment(parallel_config.world_size, rank,
  265. distributed_init_method, local_rank)
  266. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  267. parallel_config.pipeline_parallel_size)
  268. if pynccl_utils.is_initialized():
  269. pynccl_world_size = pynccl_utils.get_world_size()
  270. if pynccl_world_size != parallel_config.world_size:
  271. raise RuntimeError(
  272. "pynccl is already initialized but the pynccl world "
  273. "size does not match parallel_config.world_size "
  274. f"({pynccl_world_size} vs. {parallel_config.world_size}).")
  275. elif parallel_config.world_size > 1:
  276. # NOTE: We don't initialize pynccl process group when world size
  277. # is 1.
  278. # NOTE: By default, pynccl is initialized for tp group.
  279. pynccl_utils.init_process_group(
  280. group=get_tensor_model_parallel_cpu_group())
  281. # Initialize a custom fast all-reduce implementation.
  282. if not parallel_config.disable_custom_all_reduce:
  283. init_custom_ar()
  284. # A small all_reduce for warmup.
  285. torch.distributed.all_reduce(torch.zeros(1).cuda())
  286. if pynccl_utils.is_initialized():
  287. pynccl_utils.all_reduce(torch.zeros(1).cuda())
  288. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  289. # Check if the GPU supports the dtype.
  290. if torch_dtype == torch.bfloat16:
  291. compute_capability = torch.cuda.get_device_capability()
  292. if compute_capability[0] < 8:
  293. gpu_name = torch.cuda.get_device_name()
  294. raise ValueError(
  295. "Bfloat16 is only supported on GPUs with compute capability "
  296. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  297. f"{compute_capability[0]}.{compute_capability[1]}. "
  298. "You can use float16 instead by explicitly setting the"
  299. "`dtype` flag in CLI, for example: --dtype=half.")
  300. def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
  301. max_model_len) -> None:
  302. if num_gpu_blocks <= 0:
  303. raise ValueError("No available memory for the cache blocks. "
  304. "Try increasing `gpu_memory_utilization` when "
  305. "initializing the engine.")
  306. max_seq_len = block_size * num_gpu_blocks
  307. logger.info(f"Maximum sequence length allowed in the cache: "
  308. f"{max_seq_len}")
  309. if max_model_len > max_seq_len:
  310. raise ValueError(
  311. f"The model's max seq len ({max_model_len}) "
  312. "is larger than the maximum number of tokens that can be "
  313. f"stored in KV cache ({max_seq_len}). Try increasing "
  314. "`gpu_memory_utilization` or decreasing `max_model_len` when "
  315. "initializing the engine.")