worker.py

  1. """A GPU worker class."""
  2. import os
  3. from typing import Dict, List, Optional, Tuple
  4. import torch
  5. import torch.distributed
  6. from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
  7. SchedulerConfig)
  8. from aphrodite.modeling import set_random_seed
  9. from aphrodite.modeling.megatron.parallel_state import (
  10. initialize_model_parallel)
  11. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  12. from aphrodite.task_handler.cache_engine import CacheEngine
  13. from aphrodite.task_handler.model_runner import ModelRunner


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: Optional[int] = None,
        distributed_init_method: Optional[str] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        self.model_runner = ModelRunner(model_config, parallel_config,
                                        scheduler_config)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self) -> None:
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behaviour.
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # This env var, set by Ray, causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        # Env vars will be set by Ray.
        self.rank = self.rank if self.rank is not None else int(
            os.getenv("RANK", "-1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{local_rank}")
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
        cache_dtype: torch.dtype,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = total_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, cache_dtype, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks
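
    # Illustrative arithmetic for profile_num_available_blocks (hypothetical
    # numbers, not taken from this file): with 80 GiB of total GPU memory,
    # gpu_memory_utilization=0.90, a profiled peak of 20 GiB, and a 2 MiB
    # cache block, the KV cache gets (80 * 0.90 - 20) GiB = 52 GiB of
    # headroom, i.e. 52 GiB / 2 MiB = 26,624 GPU blocks.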

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
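        # Illustrative shapes of the block-number mappings (an assumption
        # inferred from the CacheEngine call sites below; this file does not
        # spell them out):
        #   blocks_to_swap_in:  {cpu_block_number: gpu_block_number, ...}
        #   blocks_to_swap_out: {gpu_block_number: cpu_block_number, ...}
        #   blocks_to_copy:     {src_gpu_block: [dst_gpu_block, ...], ...}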
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        # TODO: Profile swapping overhead and optimize if needed.
        if cache_events is not None:
            for event in cache_events:  # pylint: disable=not-an-iterable
                event.wait()

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            return {}

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)
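
# Note on the function above: distributed_init_method is passed straight
# through to torch.distributed.init_process_group as init_method, so the
# standard torch formats such as "tcp://<host>:<port>" or "env://" apply
# (illustrative examples, not prescribed by this file).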


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}. Please "
                "use the `--dtype float16` argument when launching the engine."
            )
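

# ---------------------------------------------------------------------------
# Illustrative driver sequence (a sketch, not part of the original file).
# It assumes the engine has already built the config objects and a CacheConfig
# sized from the profiled block counts; the literal argument values below are
# placeholders.
#
#     worker = Worker(model_config, parallel_config, scheduler_config,
#                     rank=0, distributed_init_method="tcp://127.0.0.1:2641")
#     worker.init_model()
#     worker.load_model()
#     num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
#         block_size=16, gpu_memory_utilization=0.90,
#         cpu_swap_space=4 * 1024**3, cache_dtype=torch.float16)
#     worker.init_cache_engine(cache_config)
#     worker.warm_up_model()
#     output = worker.execute_model(seq_group_metadata_list,
#                                   blocks_to_swap_in={},
#                                   blocks_to_swap_out={},
#                                   blocks_to_copy={})
# ---------------------------------------------------------------------------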