worker.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. from typing import List, Optional, Set, Tuple, Type
  5. import torch
  6. import torch.distributed
  7. from loguru import logger
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  9. LoRAConfig, ModelConfig, ParallelConfig,
  10. SchedulerConfig, SpeculativeConfig,
  11. VisionLanguageConfig)
  12. from aphrodite.common.sequence import ExecuteModelRequest
  13. from aphrodite.distributed import (ensure_model_parallel_initialized,
  14. init_distributed_environment,
  15. set_custom_all_reduce)
  16. from aphrodite.lora.request import LoRARequest
  17. from aphrodite.modeling import set_random_seed
  18. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  19. from aphrodite.platforms import current_platform
  20. from aphrodite.task_handler.cache_engine import CacheEngine
  21. from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
  22. from aphrodite.task_handler.model_runner import GPUModelRunnerBase, ModelRunner
  23. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  24. WorkerInput)
  25. class Worker(LocalOrDistributedWorkerBase):
  26. """A worker class that executes (a partition of) the model on a GPU.
  27. Each worker is associated with a single GPU. The worker is responsible for
  28. maintaining the KV cache and executing the model on the GPU. In case of
  29. distributed inference, each worker is assigned a partition of the model.
  30. """
  31. def __init__(
  32. self,
  33. model_config: ModelConfig,
  34. parallel_config: ParallelConfig,
  35. scheduler_config: SchedulerConfig,
  36. device_config: DeviceConfig,
  37. cache_config: CacheConfig,
  38. load_config: LoadConfig,
  39. local_rank: int,
  40. rank: int,
  41. distributed_init_method: str,
  42. lora_config: Optional[LoRAConfig] = None,
  43. vision_language_config: Optional[VisionLanguageConfig] = None,
  44. speculative_config: Optional[SpeculativeConfig] = None,
  45. is_driver_worker: bool = False,
  46. model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
  47. ) -> None:
  48. self.model_config = model_config
  49. self.parallel_config = parallel_config
  50. self.scheduler_config = scheduler_config
  51. self.device_config = device_config
  52. self.cache_config = cache_config
  53. self.local_rank = local_rank
  54. self.rank = rank
  55. self.distributed_init_method = distributed_init_method
  56. self.lora_config = lora_config
  57. self.load_config = load_config
  58. self.is_driver_worker = is_driver_worker
  59. if parallel_config and is_driver_worker:
  60. assert rank % parallel_config.tensor_parallel_size == 0, \
  61. "Driver worker should be rank 0 of tensor parallel group."
  62. if self.model_config.trust_remote_code:
  63. # note: lazy import to avoid importing torch before initializing
  64. from aphrodite.common.utils import init_cached_hf_modules
  65. init_cached_hf_modules()
  66. self.vision_language_config = vision_language_config
  67. if self.vision_language_config:
  68. assert not self.lora_config, (
  69. "To be tested: vision language model with LoRA settings.")
  70. # Return hidden states from target model if the draft model is an
  71. # mlp_speculator
  72. speculative_args = {} if speculative_config is None \
  73. or (speculative_config.draft_model_config.model ==
  74. model_config.model) \
  75. or (speculative_config.draft_model_config.hf_config.model_type !=
  76. "mlp_speculator") else {"return_hidden_states": True}
  77. ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
  78. if model_runner_cls is not None:
  79. ModelRunnerClass = model_runner_cls
  80. elif self.model_config.embedding_mode:
  81. ModelRunnerClass = EmbeddingModelRunner
  82. self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
  83. model_config,
  84. parallel_config,
  85. scheduler_config,
  86. device_config,
  87. cache_config,
  88. load_config=load_config,
  89. lora_config=self.lora_config,
  90. kv_cache_dtype=self.cache_config.cache_dtype,
  91. is_driver_worker=is_driver_worker,
  92. vision_language_config=vision_language_config,
  93. **speculative_args,
  94. )
  95. # Uninitialized cache engine. Will be initialized by
  96. # initialize_cache.
  97. self.cache_engine: List[CacheEngine]
  98. # Initialize gpu_cache as embedding models don't initialize kv_caches
  99. self.gpu_cache: Optional[List[List[torch.tensor]]] = None
  100. def init_device(self) -> None:
  101. if self.device_config.device.type == "cuda":
  102. # torch.distributed.all_reduce does not free the input tensor until
  103. # the synchronization point. This causes the memory usage to grow
  104. # as the number of all_reduce calls increases. This env var disables
  105. # this behavior.
  106. # Related issue:
  107. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  108. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  109. # This env var set by Ray causes exceptions with graph building.
  110. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  111. self.device = torch.device(f"cuda:{self.local_rank}")
  112. torch.cuda.set_device(self.device)
  113. _check_if_gpu_supports_dtype(self.model_config.dtype)
  114. torch.cuda.empty_cache()
  115. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  116. else:
  117. raise RuntimeError(
  118. f"Not support device type: {self.device_config.device}")
  119. # Initialize the distributed environment.
  120. init_worker_distributed_environment(self.parallel_config, self.rank,
  121. self.distributed_init_method,
  122. self.local_rank)
  123. # Set random seed.
  124. set_random_seed(self.model_config.seed)
  125. def load_model(self):
  126. self.model_runner.load_model()
  127. def save_sharded_state(
  128. self,
  129. path: str,
  130. pattern: Optional[str] = None,
  131. max_size: Optional[int] = None,
  132. ) -> None:
  133. self.model_runner.save_sharded_state(
  134. path,
  135. pattern=pattern,
  136. max_size=max_size,
  137. )
  138. def save_tensorized_model(
  139. self,
  140. tensorizer_config: TensorizerConfig,
  141. ) -> None:
  142. self.model_runner.save_tensorized_model(
  143. tensorizer_config=tensorizer_config, )
  144. @torch.inference_mode()
  145. def determine_num_available_blocks(self) -> Tuple[int, int]:
  146. """Profiles the peak memory usage of the model to determine how many
  147. KV blocks may be allocated without OOMs.
  148. The engine will first conduct a profiling of the existing memory usage.
  149. Then, it calculate the maximum possible number of GPU and CPU blocks
  150. that can be allocated with the remaining free memory.
  151. .. tip::
  152. You may limit the usage of GPU memory
  153. by adjusting the `gpu_memory_utilization` parameter.
  154. """
  155. # Profile the memory usage of the model and get the maximum number of
  156. # cache blocks that can be allocated with the remaining free memory.
  157. torch.cuda.empty_cache()
  158. # Execute a forward pass with dummy inputs to profile the memory usage
  159. # of the model.
  160. self.model_runner.profile_run()
  161. # Calculate the number of blocks that can be allocated with the
  162. # profiled peak memory.
  163. torch.cuda.synchronize()
  164. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  165. # NOTE: Here we assume that the other processes using the same
  166. # GPU did not change their memory usage during the profiling.
  167. peak_memory = self.init_gpu_memory - free_gpu_memory
  168. assert peak_memory > 0, (
  169. "Error in memory profiling. This happens when the GPU memory was "
  170. "not properly cleaned up before initializing Aphrodite.")
  171. cache_block_size = self.get_cache_block_size_bytes()
  172. num_gpu_blocks = int(
  173. (total_gpu_memory * self.cache_config.gpu_memory_utilization -
  174. peak_memory) // cache_block_size)
  175. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  176. cache_block_size)
  177. num_gpu_blocks = max(num_gpu_blocks, 0)
  178. num_cpu_blocks = max(num_cpu_blocks, 0)
  179. if self.model_runner.lora_manager:
  180. self.model_runner.remove_all_loras()
  181. gc.collect()
  182. torch.cuda.empty_cache()
  183. return num_gpu_blocks, num_cpu_blocks
  184. def initialize_cache(self, num_gpu_blocks: int,
  185. num_cpu_blocks: int) -> None:
  186. """Allocate GPU and CPU KV cache with the specified number of blocks.
  187. This also warms up the model, which may record CUDA graphs.
  188. """
  189. raise_if_cache_size_invalid(num_gpu_blocks,
  190. self.cache_config.block_size,
  191. self.model_config.max_model_len)
  192. self.cache_config.num_gpu_blocks = num_gpu_blocks
  193. self.cache_config.num_cpu_blocks = num_cpu_blocks
  194. self._init_cache_engine()
  195. self._warm_up_model()
  196. def _init_cache_engine(self):
  197. assert self.cache_config.num_gpu_blocks is not None
  198. self.cache_engine = [
  199. CacheEngine(self.cache_config, self.model_config,
  200. self.parallel_config, self.device_config)
  201. for _ in range(self.parallel_config.pipeline_parallel_size)
  202. ]
  203. self.gpu_cache = [
  204. self.cache_engine[ve].gpu_cache
  205. for ve in range(self.parallel_config.pipeline_parallel_size)
  206. ]
  207. def _warm_up_model(self) -> None:
  208. if not self.model_config.enforce_eager:
  209. self.model_runner.capture_model(self.gpu_cache)
  210. # Reset the seed to ensure that the random state is not affected by
  211. # the model initialization and profiling.
  212. set_random_seed(self.model_config.seed)
  213. @property
  214. def do_metadata_broadcast(self) -> bool:
  215. return self.parallel_config.tensor_parallel_size > 1
  216. @property
  217. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  218. return self.gpu_cache
  219. @torch.inference_mode()
  220. def prepare_worker_input(
  221. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  222. virtual_engine = execute_model_req.virtual_engine
  223. num_seq_groups = len(execute_model_req.seq_group_metadata_list)
  224. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
  225. # they contain parameters to launch cudamemcpyasync.
  226. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
  227. device="cpu",
  228. dtype=torch.int64).view(-1, 2)
  229. blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
  230. device="cpu",
  231. dtype=torch.int64).view(-1, 2)
  232. # `blocks_to_copy` is a gpu tensor. The src and tgt of
  233. # blocks to copy are in the same device, and `blocks_to_copy`
  234. # can be used directly within cuda kernels.
  235. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  236. device=self.device,
  237. dtype=torch.int64).view(-1, 2)
  238. return WorkerInput(num_seq_groups=num_seq_groups,
  239. blocks_to_swap_in=blocks_to_swap_in,
  240. blocks_to_swap_out=blocks_to_swap_out,
  241. blocks_to_copy=blocks_to_copy,
  242. virtual_engine=virtual_engine)
  243. @torch.inference_mode()
  244. def execute_worker(self, worker_input: WorkerInput) -> None:
  245. virtual_engine = worker_input.virtual_engine
  246. # Issue cache operations.
  247. if (worker_input.blocks_to_swap_in is not None
  248. and worker_input.blocks_to_swap_in.numel() > 0):
  249. self.cache_engine[virtual_engine].swap_in(
  250. worker_input.blocks_to_swap_in)
  251. if (worker_input.blocks_to_swap_out is not None
  252. and worker_input.blocks_to_swap_out.numel() > 0):
  253. self.cache_engine[virtual_engine].swap_out(
  254. worker_input.blocks_to_swap_out)
  255. if (worker_input.blocks_to_copy is not None
  256. and worker_input.blocks_to_copy.numel() > 0):
  257. self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
  258. def add_lora(self, lora_request: LoRARequest) -> bool:
  259. return self.model_runner.add_lora(lora_request)
  260. def remove_lora(self, lora_id: int) -> bool:
  261. return self.model_runner.remove_lora(lora_id)
  262. def pin_lora(self, lora_id: int) -> bool:
  263. return self.model_runner.pin_lora(lora_id)
  264. def list_loras(self) -> Set[int]:
  265. return self.model_runner.list_loras()
  266. @property
  267. def max_model_len(self) -> int:
  268. return self.model_config.max_model_len
  269. @property
  270. def vocab_size(self) -> int:
  271. return self.model_runner.vocab_size
  272. def get_cache_block_size_bytes(self) -> int:
  273. """Get the size of the KV cache block size in bytes.
  274. """
  275. return CacheEngine.get_cache_block_size(self.cache_config,
  276. self.model_config,
  277. self.parallel_config)
  278. def init_worker_distributed_environment(
  279. parallel_config: ParallelConfig,
  280. rank: int,
  281. distributed_init_method: Optional[str] = None,
  282. local_rank: int = -1,
  283. ) -> None:
  284. """Initialize the distributed environment."""
  285. set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
  286. init_distributed_environment(parallel_config.world_size, rank,
  287. distributed_init_method, local_rank)
  288. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  289. parallel_config.pipeline_parallel_size)
  290. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  291. # Check if the GPU supports the dtype.
  292. if torch_dtype == torch.bfloat16:
  293. compute_capability = current_platform.get_device_capability()
  294. if compute_capability[0] < 8:
  295. gpu_name = torch.cuda.get_device_name()
  296. raise ValueError(
  297. "Bfloat16 is only supported on GPUs with compute capability "
  298. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  299. f"{compute_capability[0]}.{compute_capability[1]}. "
  300. "You can use float16 instead by explicitly setting the"
  301. "`dtype` flag in CLI, for example: --dtype=half.")
  302. def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
  303. max_model_len) -> None:
  304. if num_gpu_blocks <= 0:
  305. raise ValueError("No available memory for the cache blocks. "
  306. "Try increasing `gpu_memory_utilization` when "
  307. "initializing the engine.")
  308. max_seq_len = block_size * num_gpu_blocks
  309. logger.info(f"Maximum sequence length allowed in the cache: "
  310. f"{max_seq_len}")
  311. if max_model_len > max_seq_len:
  312. original_max_model_len = max_model_len
  313. max_model_len = max_seq_len
  314. # raise ValueError(
  315. # f"The model's max seq len ({max_model_len}) "
  316. # "is larger than the maximum number of tokens that can be "
  317. # f"stored in KV cache ({max_seq_len}). Try increasing "
  318. # "`gpu_memory_utilization` or decreasing `max_model_len` when "
  319. # "initializing the engine.")
  320. # set the max_model_len to the max_seq_len, but raise a logger.error
  321. # so the user is made aware of this
  322. logger.error(
  323. f"The model's max seq len ({original_max_model_len}) "
  324. "is larger than the maximum number of tokens that can be "
  325. f"stored in KV cache ({max_seq_len}). "
  326. "Try increasing "
  327. "`gpu_memory_utilization`, setting "
  328. "`--enable-chunked-prefill`, or `--kv-cache-dtype fp8` "
  329. "when initializing the engine. The last two are currently "
  330. "mutually exclusive.\n"
  331. f"Forcing max_model_len to {max_seq_len}.")