worker.py

  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. from typing import Dict, List, Tuple, Set, Optional
  5. import torch
  6. import torch.distributed
  7. from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
  8. SchedulerConfig, LoRAConfig, DeviceConfig)
  9. from aphrodite.common.utils import in_wsl
  10. from aphrodite.modeling import set_random_seed
  11. from aphrodite.modeling.megatron import cupy_utils
  12. from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
  13. )
  14. from aphrodite.modeling.megatron.custom_all_reduce import init_custom_ar
  15. from aphrodite.modeling.megatron.parallel_state import (
  16. ensure_model_parallel_initialized)
  17. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  18. from aphrodite.task_handler.cache_engine import CacheEngine
  19. from aphrodite.task_handler.model_runner import ModelRunner
  20. from aphrodite.lora.request import LoRARequest
  21. from aphrodite.common.utils import is_hip


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = ModelRunner(model_config,
                                        parallel_config,
                                        scheduler_config,
                                        device_config,
                                        lora_config=self.lora_config,
                                        kv_cache_dtype=kv_cache_dtype,
                                        is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self, cupy_port: Optional[int] = None) -> None:
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor
            # until the synchronization point. This causes the memory usage to
            # grow as the number of all_reduce calls increases. This env var
            # disables this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

            # Patch for an unexpected torch.cuda.is_available() error in WSL:
            # always call torch.cuda.device_count() before initializing the
            # device.
            if in_wsl():
                torch.cuda.device_count()
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")

        # Initialize the distributed environment.
        init_distributed_environment(self.parallel_config, self.rank,
                                     cupy_port, self.distributed_init_method)
        if not self.parallel_config.disable_custom_all_reduce:
            init_custom_ar()

        # Initialize the model.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
        cache_dtype: str,
    ) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model and returns the maximum
        number of GPU and CPU cache blocks that can be allocated.

        Args:
            block_size: The number of tokens in a single cache block.
            gpu_memory_utilization: The fraction of total GPU memory to use.
            cpu_swap_space: The size of the CPU swap space in bytes.
            cache_dtype: The data type of the KV cache.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE: Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, cache_dtype, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks
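
    # Illustrative arithmetic for the block accounting above (hypothetical
    # numbers, not measured values): with total_gpu_memory = 80 GiB,
    # gpu_memory_utilization = 0.9, a profiled peak_memory of 16 GiB, and a
    # cache_block_size of 2 MiB reported by CacheEngine.get_cache_block_size:
    #     num_gpu_blocks = int((80 GiB * 0.9 - 16 GiB) // 2 MiB) = 28672
    # and with cpu_swap_space = 4 GiB:
    #     num_cpu_blocks = int(4 GiB // 2 MiB) = 2048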

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        # TODO: Profile swapping overhead and optimize if needed.
        if cache_events is not None:
            for event in cache_events:  # pylint: disable=not-an-iterable
                event.wait()
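
    # Illustrative shapes for the block-mapping arguments above (hypothetical
    # block numbers; the direction of each mapping is determined by the
    # cache-engine API, not spelled out in this file):
    #     blocks_to_swap_in  = {0: 5, 1: 6}   # source block -> destination block
    #     blocks_to_swap_out = {7: 2}         # source block -> destination block
    #     blocks_to_copy     = {3: [8, 9]}    # src block -> list of dst blocks
    # An empty dict means the corresponding cache operation is skipped.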

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()


def init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    cupy_port: Optional[int],
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch "
                "world size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    if cupy_utils.is_initialized():
        cupy_world_size = cupy_utils.get_world_size()
        if cupy_world_size != parallel_config.world_size:
            raise RuntimeError(
                "cupy.distributed is already initialized but the cupy world "
                "size does not match parallel_config.world_size "
                f"({cupy_world_size} vs. {parallel_config.world_size}).")
    elif (parallel_config.world_size > 1 and cupy_port is not None
          and not is_hip()):
        # NOTE: We don't initialize the CuPy process group when the world
        # size is 1.
        # TODO: Support multi-node connection.
        cupy_utils.init_process_group(
            world_size=parallel_config.world_size,
            rank=rank,
            host="localhost",
            port=cupy_port,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if cupy_utils.is_initialized():
        cupy_utils.all_reduce(torch.zeros(1).cuda())
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in the CLI, for example: --dtype=half.")