worker.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. import time
  5. from typing import Dict, List, Optional, Set, Tuple, Type, Union
  6. import torch
  7. import torch.distributed
  8. from loguru import logger
  9. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  10. LoRAConfig, ModelConfig, ParallelConfig,
  11. PromptAdapterConfig, SchedulerConfig,
  12. SpeculativeConfig)
  13. from aphrodite.common.sequence import (ExecuteModelRequest,
  14. IntermediateTensors,
  15. SequenceGroupMetadata,
  16. SequenceGroupMetadataDelta)
  17. from aphrodite.distributed import (ensure_model_parallel_initialized,
  18. get_tensor_model_parallel_rank,
  19. init_distributed_environment,
  20. set_custom_all_reduce)
  21. from aphrodite.lora.request import LoRARequest
  22. from aphrodite.modeling import set_random_seed
  23. from aphrodite.modeling.layers.sampler import SamplerOutput
  24. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  25. from aphrodite.platforms import current_platform
  26. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  27. from aphrodite.task_handler.cache_engine import CacheEngine
  28. from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
  29. from aphrodite.task_handler.enc_dec_model_runner import (
  30. EncoderDecoderModelRunner)
  31. from aphrodite.task_handler.model_runner import GPUModelRunnerBase, ModelRunner
  32. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  33. WorkerInput)
  34. class Worker(LocalOrDistributedWorkerBase):
  35. """A worker class that executes (a partition of) the model on a GPU.
  36. Each worker is associated with a single GPU. The worker is responsible for
  37. maintaining the KV cache and executing the model on the GPU. In case of
  38. distributed inference, each worker is assigned a partition of the model.
  39. """
  40. def __init__(
  41. self,
  42. model_config: ModelConfig,
  43. parallel_config: ParallelConfig,
  44. scheduler_config: SchedulerConfig,
  45. device_config: DeviceConfig,
  46. cache_config: CacheConfig,
  47. load_config: LoadConfig,
  48. local_rank: int,
  49. rank: int,
  50. distributed_init_method: str,
  51. lora_config: Optional[LoRAConfig] = None,
  52. speculative_config: Optional[SpeculativeConfig] = None,
  53. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  54. is_driver_worker: bool = False,
  55. model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
  56. ) -> None:
  57. self.model_config = model_config
  58. self.parallel_config = parallel_config
  59. self.parallel_config.rank = rank
  60. self.scheduler_config = scheduler_config
  61. self.device_config = device_config
  62. self.cache_config = cache_config
  63. self.local_rank = local_rank
  64. self.rank = rank
  65. self.distributed_init_method = distributed_init_method
  66. self.lora_config = lora_config
  67. self.prompt_adapter_config = prompt_adapter_config
  68. self.load_config = load_config
  69. self.is_driver_worker = is_driver_worker
  70. if parallel_config and is_driver_worker:
  71. assert rank % parallel_config.tensor_parallel_size == 0, \
  72. "Driver worker should be rank 0 of tensor parallel group."
  73. if self.model_config.trust_remote_code:
  74. # note: lazy import to avoid importing torch before initializing
  75. from aphrodite.common.utils import init_cached_hf_modules
  76. init_cached_hf_modules()
  77. # Return hidden states from target model if the draft model is an
  78. # mlp_speculator
  79. speculative_args = {} if speculative_config is None \
  80. or (speculative_config.draft_model_config.model ==
  81. model_config.model) \
  82. or (speculative_config.draft_model_config.hf_config.model_type
  83. not in ["medusa", "mlp_speculator", "eagle"]) \
  84. else {"return_hidden_states": True}
  85. ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
  86. if model_runner_cls is not None:
  87. ModelRunnerClass = model_runner_cls
  88. elif self._is_embedding_model():
  89. ModelRunnerClass = EmbeddingModelRunner
  90. elif self._is_encoder_decoder_model():
  91. ModelRunnerClass = EncoderDecoderModelRunner
  92. self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
  93. model_config,
  94. parallel_config,
  95. scheduler_config,
  96. device_config,
  97. cache_config,
  98. load_config=load_config,
  99. lora_config=self.lora_config,
  100. kv_cache_dtype=self.cache_config.cache_dtype,
  101. is_driver_worker=is_driver_worker,
  102. prompt_adapter_config=prompt_adapter_config,
  103. tp_rank=self.rank,
  104. **speculative_args,
  105. )
  106. # Uninitialized cache engine. Will be initialized by
  107. # initialize_cache.
  108. self.cache_engine: List[CacheEngine]
  109. # Initialize gpu_cache as embedding models don't initialize kv_caches
  110. self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
  111. self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
  112. def _is_encoder_decoder_model(self):
  113. return self.model_config.is_encoder_decoder_model
  114. def _is_embedding_model(self):
  115. return self.model_config.is_embedding_model
  116. def init_device(self) -> None:
  117. if self.device_config.device.type == "cuda":
  118. # torch.distributed.all_reduce does not free the input tensor until
  119. # the synchronization point. This causes the memory usage to grow
  120. # as the number of all_reduce calls increases. This env var disables
  121. # this behavior.
  122. # Related issue:
  123. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  124. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  125. # This env var set by Ray causes exceptions with graph building.
  126. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  127. self.device = torch.device(f"cuda:{self.local_rank}")
  128. torch.cuda.set_device(self.device)
  129. _check_if_gpu_supports_dtype(self.model_config.dtype)
  130. torch.cuda.empty_cache()
  131. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  132. else:
  133. raise RuntimeError(
  134. f"Not support device type: {self.device_config.device}")
  135. # Initialize the distributed environment.
  136. init_worker_distributed_environment(self.parallel_config, self.rank,
  137. self.distributed_init_method,
  138. self.local_rank)
  139. # Set random seed.
  140. set_random_seed(self.model_config.seed)
  141. def load_model(self):
  142. self.model_runner.load_model()
  143. def save_sharded_state(
  144. self,
  145. path: str,
  146. pattern: Optional[str] = None,
  147. max_size: Optional[int] = None,
  148. ) -> None:
  149. self.model_runner.save_sharded_state(
  150. path,
  151. pattern=pattern,
  152. max_size=max_size,
  153. )
  154. def save_tensorized_model(
  155. self,
  156. tensorizer_config: TensorizerConfig,
  157. ) -> None:
  158. self.model_runner.save_tensorized_model(
  159. tensorizer_config=tensorizer_config, )
  160. @torch.inference_mode()
  161. def determine_num_available_blocks(self) -> Tuple[int, int]:
  162. """Profiles the peak memory usage of the model to determine how many
  163. KV blocks may be allocated without OOMs.
  164. The engine will first conduct a profiling of the existing memory usage.
  165. Then, it calculate the maximum possible number of GPU and CPU blocks
  166. that can be allocated with the remaining free memory.
  167. .. tip::
  168. You may limit the usage of GPU memory
  169. by adjusting the `gpu_memory_utilization` parameter.
  170. """
  171. # Profile the memory usage of the model and get the maximum number of
  172. # cache blocks that can be allocated with the remaining free memory.
  173. torch.cuda.empty_cache()
  174. tp_rank = get_tensor_model_parallel_rank()
  175. # Execute a forward pass with dummy inputs to profile the memory usage
  176. # of the model.
  177. start = time.time()
  178. self.model_runner.profile_run()
  179. end = time.time()
  180. if tp_rank == 0:
  181. logger.info(f"Model profiling took {end - start:.2f} seconds.")
  182. # Calculate the number of blocks that can be allocated with the
  183. # profiled peak memory.
  184. torch.cuda.synchronize()
  185. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  186. # NOTE: Here we assume that the other processes using the same
  187. # GPU did not change their memory usage during the profiling.
  188. peak_memory = self.init_gpu_memory - free_gpu_memory
  189. assert peak_memory > 0, (
  190. "Error in memory profiling. This happens when the GPU memory was "
  191. "not properly cleaned up before initializing Aphrodite.")
  192. cache_block_size = self.get_cache_block_size_bytes()
  193. if cache_block_size == 0:
  194. num_gpu_blocks = 0
  195. num_cpu_blocks = 0
  196. else:
  197. # if single_user_mode is set to True, we only allocate enough blocks
  198. # for one sequence
  199. if self.scheduler_config.single_user_mode:
  200. num_gpu_blocks = (self.model_config.max_model_len +
  201. self.cache_config.block_size - 1
  202. ) // self.cache_config.block_size
  203. max_possible_blocks = int(
  204. (total_gpu_memory *
  205. self.cache_config.gpu_memory_utilization -
  206. peak_memory) // cache_block_size)
  207. num_gpu_blocks = min(num_gpu_blocks, max_possible_blocks)
  208. if tp_rank == 0:
  209. logger.info(
  210. f"Single sequence mode: Allocating {num_gpu_blocks} "
  211. "blocks "
  212. f"({num_gpu_blocks * self.cache_config.block_size} "
  213. "tokens)")
  214. else:
  215. # Original logic for multi-sequence mode
  216. num_gpu_blocks = int(
  217. (total_gpu_memory *
  218. self.cache_config.gpu_memory_utilization -
  219. peak_memory) // cache_block_size)
  220. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  221. cache_block_size)
  222. num_gpu_blocks = max(num_gpu_blocks, 0)
  223. num_cpu_blocks = max(num_cpu_blocks, 0)
  224. if self.model_runner.lora_manager:
  225. self.model_runner.remove_all_loras()
  226. gc.collect()
  227. torch.cuda.empty_cache()
  228. return num_gpu_blocks, num_cpu_blocks
  229. def initialize_cache(self, num_gpu_blocks: int,
  230. num_cpu_blocks: int) -> None:
  231. """Allocate GPU and CPU KV cache with the specified number of blocks.
  232. This also warms up the model, which may record CUDA graphs.
  233. """
  234. raise_if_cache_size_invalid(num_gpu_blocks,
  235. self.cache_config.block_size,
  236. self.cache_config.is_attention_free,
  237. self.model_config.max_model_len)
  238. self.cache_config.num_gpu_blocks = num_gpu_blocks
  239. self.cache_config.num_cpu_blocks = num_cpu_blocks
  240. self._init_cache_engine()
  241. self._warm_up_model()
  242. def _init_cache_engine(self):
  243. assert self.cache_config.num_gpu_blocks is not None
  244. self.cache_engine = [
  245. CacheEngine(self.cache_config, self.model_config,
  246. self.parallel_config, self.device_config, self.rank)
  247. for _ in range(self.parallel_config.pipeline_parallel_size)
  248. ]
  249. self.gpu_cache = [
  250. self.cache_engine[ve].gpu_cache
  251. for ve in range(self.parallel_config.pipeline_parallel_size)
  252. ]
  253. def _warm_up_model(self) -> None:
  254. if not self.model_config.enforce_eager:
  255. self.model_runner.capture_model(self.gpu_cache)
  256. # Reset the seed to ensure that the random state is not affected by
  257. # the model initialization and profiling.
  258. set_random_seed(self.model_config.seed)
  259. @property
  260. def do_metadata_broadcast(self) -> bool:
  261. return self.parallel_config.tensor_parallel_size > 1
  262. @property
  263. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  264. return self.gpu_cache
  265. @torch.inference_mode()
  266. def prepare_worker_input(
  267. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  268. virtual_engine = execute_model_req.virtual_engine
  269. num_steps = execute_model_req.num_steps
  270. num_seq_groups = len(execute_model_req.seq_group_metadata_list)
  271. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
  272. # they contain parameters to launch cudamemcpyasync.
  273. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
  274. device="cpu",
  275. dtype=torch.int64).view(-1, 2)
  276. blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
  277. device="cpu",
  278. dtype=torch.int64).view(-1, 2)
  279. # `blocks_to_copy` is a gpu tensor. The src and tgt of
  280. # blocks to copy are in the same device, and `blocks_to_copy`
  281. # can be used directly within cuda kernels.
  282. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  283. device=self.device,
  284. dtype=torch.int64).view(-1, 2)
  285. return WorkerInput(num_seq_groups=num_seq_groups,
  286. blocks_to_swap_in=blocks_to_swap_in,
  287. blocks_to_swap_out=blocks_to_swap_out,
  288. blocks_to_copy=blocks_to_copy,
  289. virtual_engine=virtual_engine,
  290. num_steps=num_steps)
  291. @torch.inference_mode()
  292. def execute_worker(self, worker_input: WorkerInput) -> None:
  293. virtual_engine = worker_input.virtual_engine
  294. # Issue cache operations.
  295. if (worker_input.blocks_to_swap_in is not None
  296. and worker_input.blocks_to_swap_in.numel() > 0):
  297. self.cache_engine[virtual_engine].swap_in(
  298. worker_input.blocks_to_swap_in)
  299. if (worker_input.blocks_to_swap_out is not None
  300. and worker_input.blocks_to_swap_out.numel() > 0):
  301. self.cache_engine[virtual_engine].swap_out(
  302. worker_input.blocks_to_swap_out)
  303. if (worker_input.blocks_to_copy is not None
  304. and worker_input.blocks_to_copy.numel() > 0):
  305. self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
  306. def _get_cached_seq_group_metadata(
  307. self,
  308. seq_group_metadata_list: List[Union[SequenceGroupMetadata,
  309. SequenceGroupMetadataDelta]],
  310. finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
  311. """Return a list of cached Sequence Group Metadata after updating its
  312. state.
  313. It is used because scheduler only sends delta to workers to reduce
  314. the data payload size. The function also cleans up cache based on
  315. a given `finished_request_ids`.
  316. """
  317. new_seq_group_metadata_list = []
  318. for metadata_or_delta in seq_group_metadata_list:
  319. request_id = metadata_or_delta.request_id
  320. if request_id not in self._seq_group_metadata_cache:
  321. # The first prefill.
  322. assert isinstance(metadata_or_delta, SequenceGroupMetadata)
  323. self._seq_group_metadata_cache[request_id] = metadata_or_delta
  324. else:
  325. # The first prefill is already cached.
  326. if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
  327. self._seq_group_metadata_cache[request_id].apply_delta(
  328. metadata_or_delta)
  329. else:
  330. # If metadata snapshot is sent again, it is
  331. # preempted. Reset the cache because we need to start
  332. # from scratch.
  333. assert isinstance(metadata_or_delta, SequenceGroupMetadata)
  334. self._seq_group_metadata_cache[
  335. request_id] = metadata_or_delta
  336. new_seq_group_metadata_list.append(
  337. self._seq_group_metadata_cache[request_id])
  338. # Clean up finished ids
  339. for finished_id in finished_request_ids:
  340. del self._seq_group_metadata_cache[finished_id]
  341. return new_seq_group_metadata_list
  342. def _execute_model_spmd(
  343. self,
  344. execute_model_req: ExecuteModelRequest,
  345. intermediate_tensors: Optional[IntermediateTensors] = None,
  346. ) -> Optional[List[SamplerOutput]]:
  347. if execute_model_req is not None:
  348. new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
  349. execute_model_req.seq_group_metadata_list,
  350. execute_model_req.finished_requests_ids)
  351. execute_model_req.seq_group_metadata_list = (
  352. new_seq_group_metadata_list)
  353. output = super()._execute_model_spmd(execute_model_req,
  354. intermediate_tensors)
  355. return output
  356. def add_lora(self, lora_request: LoRARequest) -> bool:
  357. return self.model_runner.add_lora(lora_request)
  358. def remove_lora(self, lora_id: int) -> bool:
  359. return self.model_runner.remove_lora(lora_id)
  360. def pin_lora(self, lora_id: int) -> bool:
  361. return self.model_runner.pin_lora(lora_id)
  362. def list_loras(self) -> Set[int]:
  363. return self.model_runner.list_loras()
  364. def add_prompt_adapter(
  365. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  366. return self.model_runner.add_prompt_adapter(prompt_adapter_request)
  367. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  368. return self.model_runner.remove_lora(prompt_adapter_id)
  369. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  370. return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
  371. def list_prompt_adapters(self) -> Set[int]:
  372. return self.model_runner.list_prompt_adapters()
  373. @property
  374. def max_model_len(self) -> int:
  375. return self.model_config.max_model_len
  376. @property
  377. def vocab_size(self) -> int:
  378. return self.model_runner.vocab_size
  379. def get_cache_block_size_bytes(self) -> int:
  380. """Get the size of the KV cache block size in bytes.
  381. """
  382. return CacheEngine.get_cache_block_size(self.cache_config,
  383. self.model_config,
  384. self.parallel_config)
  385. def init_worker_distributed_environment(
  386. parallel_config: ParallelConfig,
  387. rank: int,
  388. distributed_init_method: Optional[str] = None,
  389. local_rank: int = -1,
  390. ) -> None:
  391. """Initialize the distributed environment."""
  392. set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
  393. init_distributed_environment(parallel_config.world_size, rank,
  394. distributed_init_method, local_rank)
  395. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  396. parallel_config.pipeline_parallel_size)
  397. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  398. # Check if the GPU supports the dtype.
  399. if torch_dtype == torch.bfloat16:
  400. compute_capability = current_platform.get_device_capability()
  401. if compute_capability[0] < 8:
  402. gpu_name = current_platform.get_device_name()
  403. raise ValueError(
  404. "Bfloat16 is only supported on GPUs with compute capability "
  405. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  406. f"{compute_capability[0]}.{compute_capability[1]}. "
  407. "You can use float16 instead by explicitly setting the"
  408. "`dtype` flag in CLI, for example: --dtype=half.")
  409. def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
  410. max_model_len) -> None:
  411. if is_attention_free and num_gpu_blocks != 0:
  412. raise ValueError("No memory should be allocated for the cache blocks "
  413. f"for an attention-free model, but {num_gpu_blocks}"
  414. "blocks are allocated.")
  415. if not is_attention_free and num_gpu_blocks <= 0:
  416. raise ValueError("No available memory for the cache blocks. "
  417. "Try increasing `gpu_memory_utilization` when "
  418. "initializing the engine.")
  419. max_seq_len = block_size * num_gpu_blocks
  420. rank = get_tensor_model_parallel_rank()
  421. if rank == 0:
  422. logger.info(f"Maximum sequence length allowed in the cache: "
  423. f"{max_seq_len}")
  424. if not is_attention_free and max_model_len > max_seq_len:
  425. original_max_model_len = max_model_len
  426. max_model_len = max_seq_len
  427. # raise ValueError(
  428. # f"The model's max seq len ({max_model_len}) "
  429. # "is larger than the maximum number of tokens that can be "
  430. # f"stored in KV cache ({max_seq_len}). Try increasing "
  431. # "`gpu_memory_utilization` or decreasing `max_model_len` when "
  432. # "initializing the engine.")
  433. # set the max_model_len to the max_seq_len, but raise a logger.error
  434. # so the user is made aware of this
  435. logger.error(
  436. f"The model's max seq len ({original_max_model_len}) "
  437. "is larger than the maximum number of tokens that can be "
  438. f"stored in KV cache ({max_seq_len}). "
  439. "Try increasing "
  440. "`gpu_memory_utilization`, setting "
  441. "`--enable-chunked-prefill`, or `--kv-cache-dtype fp8` "
  442. "when initializing the engine. The last two are currently "
  443. "mutually exclusive.\n"
  444. f"Forcing max_model_len to {max_seq_len}.")