worker.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. import time
  5. from typing import List, Optional, Set, Tuple, Type
  6. import torch
  7. import torch.distributed
  8. from loguru import logger
  9. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  10. LoRAConfig, ModelConfig, MultiModalConfig,
  11. ParallelConfig, PromptAdapterConfig,
  12. SchedulerConfig, SpeculativeConfig)
  13. from aphrodite.common.sequence import ExecuteModelRequest
  14. from aphrodite.common.utils import (is_embedding_model_config,
  15. is_encoder_decoder_model_config)
  16. from aphrodite.distributed import (ensure_model_parallel_initialized,
  17. get_tensor_model_parallel_rank,
  18. get_tensor_model_parallel_world_size,
  19. init_distributed_environment,
  20. set_custom_all_reduce)
  21. from aphrodite.lora.request import LoRARequest
  22. from aphrodite.modeling import set_random_seed
  23. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  24. from aphrodite.platforms import current_platform
  25. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  26. from aphrodite.task_handler.cache_engine import CacheEngine
  27. from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
  28. from aphrodite.task_handler.enc_dec_model_runner import (
  29. EncoderDecoderModelRunner)
  30. from aphrodite.task_handler.model_runner import GPUModelRunnerBase, ModelRunner
  31. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  32. WorkerInput)
  33. class Worker(LocalOrDistributedWorkerBase):
  34. """A worker class that executes (a partition of) the model on a GPU.
  35. Each worker is associated with a single GPU. The worker is responsible for
  36. maintaining the KV cache and executing the model on the GPU. In case of
  37. distributed inference, each worker is assigned a partition of the model.
  38. """
  39. def __init__(
  40. self,
  41. model_config: ModelConfig,
  42. parallel_config: ParallelConfig,
  43. scheduler_config: SchedulerConfig,
  44. device_config: DeviceConfig,
  45. cache_config: CacheConfig,
  46. load_config: LoadConfig,
  47. local_rank: int,
  48. rank: int,
  49. distributed_init_method: str,
  50. lora_config: Optional[LoRAConfig] = None,
  51. multimodal_config: Optional[MultiModalConfig] = None,
  52. speculative_config: Optional[SpeculativeConfig] = None,
  53. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  54. is_driver_worker: bool = False,
  55. model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
  56. ) -> None:
  57. self.model_config = model_config
  58. self.parallel_config = parallel_config
  59. self.parallel_config.rank = rank
  60. self.scheduler_config = scheduler_config
  61. self.device_config = device_config
  62. self.cache_config = cache_config
  63. self.local_rank = local_rank
  64. self.rank = rank
  65. self.distributed_init_method = distributed_init_method
  66. self.lora_config = lora_config
  67. self.prompt_adapter_config = prompt_adapter_config
  68. self.load_config = load_config
  69. self.is_driver_worker = is_driver_worker
  70. if parallel_config and is_driver_worker:
  71. assert rank % parallel_config.tensor_parallel_size == 0, \
  72. "Driver worker should be rank 0 of tensor parallel group."
  73. if self.model_config.trust_remote_code:
  74. # note: lazy import to avoid importing torch before initializing
  75. from aphrodite.common.utils import init_cached_hf_modules
  76. init_cached_hf_modules()
  77. self.multimodal_config = multimodal_config
  78. # Return hidden states from target model if the draft model is an
  79. # mlp_speculator
  80. speculative_args = {} if speculative_config is None \
  81. or (speculative_config.draft_model_config.model ==
  82. model_config.model) \
  83. or (speculative_config.draft_model_config.hf_config.model_type
  84. not in ["medusa", "mlp_speculator"]) \
  85. else {"return_hidden_states": True}
  86. ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
  87. if model_runner_cls is not None:
  88. ModelRunnerClass = model_runner_cls
  89. elif self._is_embedding_model():
  90. ModelRunnerClass = EmbeddingModelRunner
  91. elif self._is_encoder_decoder_model():
  92. ModelRunnerClass = EncoderDecoderModelRunner
  93. self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
  94. model_config,
  95. parallel_config,
  96. scheduler_config,
  97. device_config,
  98. cache_config,
  99. load_config=load_config,
  100. lora_config=self.lora_config,
  101. kv_cache_dtype=self.cache_config.cache_dtype,
  102. is_driver_worker=is_driver_worker,
  103. prompt_adapter_config=prompt_adapter_config,
  104. multimodal_config=multimodal_config,
  105. tp_rank=self.rank,
  106. **speculative_args,
  107. )
  108. # Uninitialized cache engine. Will be initialized by
  109. # initialize_cache.
  110. self.cache_engine: List[CacheEngine]
  111. # Initialize gpu_cache as embedding models don't initialize kv_caches
  112. self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
  113. def _is_encoder_decoder_model(self):
  114. return is_encoder_decoder_model_config(self.model_config)
  115. def _is_embedding_model(self):
  116. return is_embedding_model_config(self.model_config)
  117. def init_device(self) -> None:
  118. if self.device_config.device.type == "cuda":
  119. # torch.distributed.all_reduce does not free the input tensor until
  120. # the synchronization point. This causes the memory usage to grow
  121. # as the number of all_reduce calls increases. This env var disables
  122. # this behavior.
  123. # Related issue:
  124. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  125. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  126. # This env var set by Ray causes exceptions with graph building.
  127. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  128. self.device = torch.device(f"cuda:{self.local_rank}")
  129. torch.cuda.set_device(self.device)
  130. _check_if_gpu_supports_dtype(self.model_config.dtype)
  131. torch.cuda.empty_cache()
  132. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  133. else:
  134. raise RuntimeError(
  135. f"Not support device type: {self.device_config.device}")
  136. # Initialize the distributed environment.
  137. init_worker_distributed_environment(self.parallel_config, self.rank,
  138. self.distributed_init_method,
  139. self.local_rank)
  140. # Set random seed.
  141. set_random_seed(self.model_config.seed)
  142. def load_model(self):
  143. self.model_runner.load_model()
  144. def save_sharded_state(
  145. self,
  146. path: str,
  147. pattern: Optional[str] = None,
  148. max_size: Optional[int] = None,
  149. ) -> None:
  150. self.model_runner.save_sharded_state(
  151. path,
  152. pattern=pattern,
  153. max_size=max_size,
  154. )
  155. def save_tensorized_model(
  156. self,
  157. tensorizer_config: TensorizerConfig,
  158. ) -> None:
  159. self.model_runner.save_tensorized_model(
  160. tensorizer_config=tensorizer_config, )
  161. @torch.inference_mode()
  162. def determine_num_available_blocks(self) -> Tuple[int, int]:
  163. """Profiles the peak memory usage of the model to determine how many
  164. KV blocks may be allocated without OOMs.
  165. The engine will first conduct a profiling of the existing memory usage.
  166. Then, it calculate the maximum possible number of GPU and CPU blocks
  167. that can be allocated with the remaining free memory.
  168. .. tip::
  169. You may limit the usage of GPU memory
  170. by adjusting the `gpu_memory_utilization` parameter.
  171. """
  172. # Profile the memory usage of the model and get the maximum number of
  173. # cache blocks that can be allocated with the remaining free memory.
  174. torch.cuda.empty_cache()
  175. tp_rank = get_tensor_model_parallel_rank()
  176. world_size = get_tensor_model_parallel_world_size()
  177. # Execute a forward pass with dummy inputs to profile the memory usage
  178. # of the model.
  179. start = time.time()
  180. self.model_runner.profile_run()
  181. end = time.time()
  182. if tp_rank == 0:
  183. logger.info(f"Model profiling took {end - start:.2f} seconds.")
  184. # Calculate the number of blocks that can be allocated with the
  185. # profiled peak memory.
  186. torch.cuda.synchronize()
  187. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  188. # NOTE: Here we assume that the other processes using the same
  189. # GPU did not change their memory usage during the profiling.
  190. peak_memory = self.init_gpu_memory - free_gpu_memory
  191. weights_memory = self.model_runner.get_model_memory_usage()
  192. kv_cache_memory = peak_memory - weights_memory
  193. # log peak memory usage in GB
  194. if world_size > 1:
  195. if tp_rank == 0:
  196. logger.info(f"KV Cache memory usage for "
  197. f"{self.model_config.max_model_len} tokens: "
  198. f"{(kv_cache_memory) / 1e9:.2f} x {world_size} = "
  199. f"{(kv_cache_memory * world_size) / 1e9:.2f} GB")
  200. else:
  201. logger.info(f"KV Cache memory usage for "
  202. f"{self.model_config.max_model_len} tokens: "
  203. f"{(kv_cache_memory) / 1e9:.2f} GB")
  204. assert peak_memory > 0, (
  205. "Error in memory profiling. This happens when the GPU memory was "
  206. "not properly cleaned up before initializing Aphrodite.")
  207. cache_block_size = self.get_cache_block_size_bytes()
  208. num_gpu_blocks = int(
  209. (total_gpu_memory * self.cache_config.gpu_memory_utilization -
  210. peak_memory) // cache_block_size)
  211. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  212. cache_block_size)
  213. num_gpu_blocks = max(num_gpu_blocks, 0)
  214. num_cpu_blocks = max(num_cpu_blocks, 0)
  215. if self.model_runner.lora_manager:
  216. self.model_runner.remove_all_loras()
  217. gc.collect()
  218. torch.cuda.empty_cache()
  219. return num_gpu_blocks, num_cpu_blocks
  220. def initialize_cache(self, num_gpu_blocks: int,
  221. num_cpu_blocks: int) -> None:
  222. """Allocate GPU and CPU KV cache with the specified number of blocks.
  223. This also warms up the model, which may record CUDA graphs.
  224. """
  225. raise_if_cache_size_invalid(num_gpu_blocks,
  226. self.cache_config.block_size,
  227. self.model_config.max_model_len)
  228. self.cache_config.num_gpu_blocks = num_gpu_blocks
  229. self.cache_config.num_cpu_blocks = num_cpu_blocks
  230. self._init_cache_engine()
  231. self._warm_up_model()
  232. def _init_cache_engine(self):
  233. assert self.cache_config.num_gpu_blocks is not None
  234. self.cache_engine = [
  235. CacheEngine(self.cache_config, self.model_config,
  236. self.parallel_config, self.device_config, self.rank)
  237. for _ in range(self.parallel_config.pipeline_parallel_size)
  238. ]
  239. self.gpu_cache = [
  240. self.cache_engine[ve].gpu_cache
  241. for ve in range(self.parallel_config.pipeline_parallel_size)
  242. ]
  243. def _warm_up_model(self) -> None:
  244. if not self.model_config.enforce_eager:
  245. self.model_runner.capture_model(self.gpu_cache)
  246. # Reset the seed to ensure that the random state is not affected by
  247. # the model initialization and profiling.
  248. set_random_seed(self.model_config.seed)
  249. @property
  250. def do_metadata_broadcast(self) -> bool:
  251. return self.parallel_config.tensor_parallel_size > 1
  252. @property
  253. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  254. return self.gpu_cache
  255. @torch.inference_mode()
  256. def prepare_worker_input(
  257. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  258. virtual_engine = execute_model_req.virtual_engine
  259. num_seq_groups = len(execute_model_req.seq_group_metadata_list)
  260. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
  261. # they contain parameters to launch cudamemcpyasync.
  262. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
  263. device="cpu",
  264. dtype=torch.int64).view(-1, 2)
  265. blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
  266. device="cpu",
  267. dtype=torch.int64).view(-1, 2)
  268. # `blocks_to_copy` is a gpu tensor. The src and tgt of
  269. # blocks to copy are in the same device, and `blocks_to_copy`
  270. # can be used directly within cuda kernels.
  271. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  272. device=self.device,
  273. dtype=torch.int64).view(-1, 2)
  274. return WorkerInput(num_seq_groups=num_seq_groups,
  275. blocks_to_swap_in=blocks_to_swap_in,
  276. blocks_to_swap_out=blocks_to_swap_out,
  277. blocks_to_copy=blocks_to_copy,
  278. virtual_engine=virtual_engine)
  279. @torch.inference_mode()
  280. def execute_worker(self, worker_input: WorkerInput) -> None:
  281. virtual_engine = worker_input.virtual_engine
  282. # Issue cache operations.
  283. if (worker_input.blocks_to_swap_in is not None
  284. and worker_input.blocks_to_swap_in.numel() > 0):
  285. self.cache_engine[virtual_engine].swap_in(
  286. worker_input.blocks_to_swap_in)
  287. if (worker_input.blocks_to_swap_out is not None
  288. and worker_input.blocks_to_swap_out.numel() > 0):
  289. self.cache_engine[virtual_engine].swap_out(
  290. worker_input.blocks_to_swap_out)
  291. if (worker_input.blocks_to_copy is not None
  292. and worker_input.blocks_to_copy.numel() > 0):
  293. self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
  294. def add_lora(self, lora_request: LoRARequest) -> bool:
  295. return self.model_runner.add_lora(lora_request)
  296. def remove_lora(self, lora_id: int) -> bool:
  297. return self.model_runner.remove_lora(lora_id)
  298. def pin_lora(self, lora_id: int) -> bool:
  299. return self.model_runner.pin_lora(lora_id)
  300. def list_loras(self) -> Set[int]:
  301. return self.model_runner.list_loras()
  302. def add_prompt_adapter(
  303. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  304. return self.model_runner.add_prompt_adapter(prompt_adapter_request)
  305. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  306. return self.model_runner.remove_lora(prompt_adapter_id)
  307. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  308. return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
  309. def list_prompt_adapters(self) -> Set[int]:
  310. return self.model_runner.list_prompt_adapters()
  311. @property
  312. def max_model_len(self) -> int:
  313. return self.model_config.max_model_len
  314. @property
  315. def vocab_size(self) -> int:
  316. return self.model_runner.vocab_size
  317. def get_cache_block_size_bytes(self) -> int:
  318. """Get the size of the KV cache block size in bytes.
  319. """
  320. return CacheEngine.get_cache_block_size(self.cache_config,
  321. self.model_config,
  322. self.parallel_config)
  323. def init_worker_distributed_environment(
  324. parallel_config: ParallelConfig,
  325. rank: int,
  326. distributed_init_method: Optional[str] = None,
  327. local_rank: int = -1,
  328. ) -> None:
  329. """Initialize the distributed environment."""
  330. set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
  331. init_distributed_environment(parallel_config.world_size, rank,
  332. distributed_init_method, local_rank)
  333. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  334. parallel_config.pipeline_parallel_size)
  335. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  336. # Check if the GPU supports the dtype.
  337. if torch_dtype == torch.bfloat16:
  338. compute_capability = current_platform.get_device_capability()
  339. if compute_capability[0] < 8:
  340. gpu_name = torch.cuda.get_device_name()
  341. raise ValueError(
  342. "Bfloat16 is only supported on GPUs with compute capability "
  343. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  344. f"{compute_capability[0]}.{compute_capability[1]}. "
  345. "You can use float16 instead by explicitly setting the"
  346. "`dtype` flag in CLI, for example: --dtype=half.")
  347. def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
  348. max_model_len) -> None:
  349. rank = get_tensor_model_parallel_rank()
  350. if num_gpu_blocks <= 0:
  351. raise ValueError("No available memory for the cache blocks. "
  352. "Try increasing `gpu_memory_utilization` when "
  353. "initializing the engine.")
  354. max_seq_len = block_size * num_gpu_blocks
  355. if rank == 0:
  356. logger.info(f"Maximum sequence length allowed in the cache: "
  357. f"{max_seq_len}")
  358. if max_model_len > max_seq_len:
  359. original_max_model_len = max_model_len
  360. max_model_len = max_seq_len
  361. # raise ValueError(
  362. # f"The model's max seq len ({max_model_len}) "
  363. # "is larger than the maximum number of tokens that can be "
  364. # f"stored in KV cache ({max_seq_len}). Try increasing "
  365. # "`gpu_memory_utilization` or decreasing `max_model_len` when "
  366. # "initializing the engine.")
  367. # set the max_model_len to the max_seq_len, but raise a logger.error
  368. # so the user is made aware of this
  369. logger.error(
  370. f"The model's max seq len ({original_max_model_len}) "
  371. "is larger than the maximum number of tokens that can be "
  372. f"stored in KV cache ({max_seq_len}). "
  373. "Try increasing "
  374. "`gpu_memory_utilization`, setting "
  375. "`--enable-chunked-prefill`, or `--kv-cache-dtype fp8` "
  376. "when initializing the engine. The last two are currently "
  377. "mutually exclusive.\n"
  378. f"Forcing max_model_len to {max_seq_len}.")