1
0

worker.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. """A GPU worker class."""
  2. import gc
  3. import os
  4. import time
  5. from typing import List, Optional, Set, Tuple, Type
  6. import torch
  7. import torch.distributed
  8. from loguru import logger
  9. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  10. LoRAConfig, ModelConfig, ParallelConfig,
  11. PromptAdapterConfig, SchedulerConfig,
  12. SpeculativeConfig)
  13. from aphrodite.common.sequence import ExecuteModelRequest
  14. from aphrodite.distributed import (ensure_model_parallel_initialized,
  15. get_tensor_model_parallel_rank,
  16. get_tensor_model_parallel_world_size,
  17. init_distributed_environment,
  18. set_custom_all_reduce)
  19. from aphrodite.lora.request import LoRARequest
  20. from aphrodite.modeling import set_random_seed
  21. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  22. from aphrodite.platforms import current_platform
  23. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  24. from aphrodite.task_handler.cache_engine import CacheEngine
  25. from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
  26. from aphrodite.task_handler.enc_dec_model_runner import (
  27. EncoderDecoderModelRunner)
  28. from aphrodite.task_handler.model_runner import GPUModelRunnerBase, ModelRunner
  29. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  30. WorkerInput)
  31. class Worker(LocalOrDistributedWorkerBase):
  32. """A worker class that executes (a partition of) the model on a GPU.
  33. Each worker is associated with a single GPU. The worker is responsible for
  34. maintaining the KV cache and executing the model on the GPU. In case of
  35. distributed inference, each worker is assigned a partition of the model.
  36. """
  37. def __init__(
  38. self,
  39. model_config: ModelConfig,
  40. parallel_config: ParallelConfig,
  41. scheduler_config: SchedulerConfig,
  42. device_config: DeviceConfig,
  43. cache_config: CacheConfig,
  44. load_config: LoadConfig,
  45. local_rank: int,
  46. rank: int,
  47. distributed_init_method: str,
  48. lora_config: Optional[LoRAConfig] = None,
  49. speculative_config: Optional[SpeculativeConfig] = None,
  50. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  51. is_driver_worker: bool = False,
  52. model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
  53. ) -> None:
  54. self.model_config = model_config
  55. self.parallel_config = parallel_config
  56. self.parallel_config.rank = rank
  57. self.scheduler_config = scheduler_config
  58. self.device_config = device_config
  59. self.cache_config = cache_config
  60. self.local_rank = local_rank
  61. self.rank = rank
  62. self.distributed_init_method = distributed_init_method
  63. self.lora_config = lora_config
  64. self.prompt_adapter_config = prompt_adapter_config
  65. self.load_config = load_config
  66. self.is_driver_worker = is_driver_worker
  67. if parallel_config and is_driver_worker:
  68. assert rank % parallel_config.tensor_parallel_size == 0, \
  69. "Driver worker should be rank 0 of tensor parallel group."
  70. if self.model_config.trust_remote_code:
  71. # note: lazy import to avoid importing torch before initializing
  72. from aphrodite.common.utils import init_cached_hf_modules
  73. init_cached_hf_modules()
  74. # Return hidden states from target model if the draft model is an
  75. # mlp_speculator
  76. speculative_args = {} if speculative_config is None \
  77. or (speculative_config.draft_model_config.model ==
  78. model_config.model) \
  79. or (speculative_config.draft_model_config.hf_config.model_type
  80. not in ["medusa", "mlp_speculator"]) \
  81. else {"return_hidden_states": True}
  82. ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
  83. if model_runner_cls is not None:
  84. ModelRunnerClass = model_runner_cls
  85. elif self._is_embedding_model():
  86. ModelRunnerClass = EmbeddingModelRunner
  87. elif self._is_encoder_decoder_model():
  88. ModelRunnerClass = EncoderDecoderModelRunner
  89. self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
  90. model_config,
  91. parallel_config,
  92. scheduler_config,
  93. device_config,
  94. cache_config,
  95. load_config=load_config,
  96. lora_config=self.lora_config,
  97. kv_cache_dtype=self.cache_config.cache_dtype,
  98. is_driver_worker=is_driver_worker,
  99. prompt_adapter_config=prompt_adapter_config,
  100. tp_rank=self.rank,
  101. **speculative_args,
  102. )
  103. # Uninitialized cache engine. Will be initialized by
  104. # initialize_cache.
  105. self.cache_engine: List[CacheEngine]
  106. # Initialize gpu_cache as embedding models don't initialize kv_caches
  107. self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
  108. def _is_encoder_decoder_model(self):
  109. return self.model_config.is_encoder_decoder_model
  110. def _is_embedding_model(self):
  111. return self.model_config.is_embedding_model
  112. def init_device(self) -> None:
  113. if self.device_config.device.type == "cuda":
  114. # torch.distributed.all_reduce does not free the input tensor until
  115. # the synchronization point. This causes the memory usage to grow
  116. # as the number of all_reduce calls increases. This env var disables
  117. # this behavior.
  118. # Related issue:
  119. # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
  120. os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
  121. # This env var set by Ray causes exceptions with graph building.
  122. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
  123. self.device = torch.device(f"cuda:{self.local_rank}")
  124. torch.cuda.set_device(self.device)
  125. _check_if_gpu_supports_dtype(self.model_config.dtype)
  126. torch.cuda.empty_cache()
  127. self.init_gpu_memory = torch.cuda.mem_get_info()[0]
  128. else:
  129. raise RuntimeError(
  130. f"Not support device type: {self.device_config.device}")
  131. # Initialize the distributed environment.
  132. init_worker_distributed_environment(self.parallel_config, self.rank,
  133. self.distributed_init_method,
  134. self.local_rank)
  135. # Set random seed.
  136. set_random_seed(self.model_config.seed)
  137. def load_model(self):
  138. self.model_runner.load_model()
  139. def save_sharded_state(
  140. self,
  141. path: str,
  142. pattern: Optional[str] = None,
  143. max_size: Optional[int] = None,
  144. ) -> None:
  145. self.model_runner.save_sharded_state(
  146. path,
  147. pattern=pattern,
  148. max_size=max_size,
  149. )
  150. def save_tensorized_model(
  151. self,
  152. tensorizer_config: TensorizerConfig,
  153. ) -> None:
  154. self.model_runner.save_tensorized_model(
  155. tensorizer_config=tensorizer_config, )
  156. @torch.inference_mode()
  157. def determine_num_available_blocks(self) -> Tuple[int, int]:
  158. """Profiles the peak memory usage of the model to determine how many
  159. KV blocks may be allocated without OOMs.
  160. The engine will first conduct a profiling of the existing memory usage.
  161. Then, it calculate the maximum possible number of GPU and CPU blocks
  162. that can be allocated with the remaining free memory.
  163. .. tip::
  164. You may limit the usage of GPU memory
  165. by adjusting the `gpu_memory_utilization` parameter.
  166. """
  167. # Profile the memory usage of the model and get the maximum number of
  168. # cache blocks that can be allocated with the remaining free memory.
  169. torch.cuda.empty_cache()
  170. tp_rank = get_tensor_model_parallel_rank()
  171. world_size = get_tensor_model_parallel_world_size()
  172. # Execute a forward pass with dummy inputs to profile the memory usage
  173. # of the model.
  174. start = time.time()
  175. self.model_runner.profile_run()
  176. end = time.time()
  177. if tp_rank == 0:
  178. logger.info(f"Model profiling took {end - start:.2f} seconds.")
  179. # Calculate the number of blocks that can be allocated with the
  180. # profiled peak memory.
  181. torch.cuda.synchronize()
  182. free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  183. # NOTE: Here we assume that the other processes using the same
  184. # GPU did not change their memory usage during the profiling.
  185. peak_memory = self.init_gpu_memory - free_gpu_memory
  186. weights_memory = self.model_runner.get_model_memory_usage()
  187. kv_cache_memory = peak_memory - weights_memory
  188. # log peak memory usage in GB
  189. if world_size > 1:
  190. if tp_rank == 0:
  191. logger.info(f"Estimated KV Cache memory usage: "
  192. f"{(kv_cache_memory) / 1e9:.2f} x {world_size} = "
  193. f"{(kv_cache_memory * world_size) / 1e9:.2f} GB")
  194. else:
  195. logger.info(f"Estimated KV Cache memory usage: "
  196. f"{(kv_cache_memory) / 1e9:.2f} GB")
  197. assert peak_memory > 0, (
  198. "Error in memory profiling. This happens when the GPU memory was "
  199. "not properly cleaned up before initializing Aphrodite.")
  200. cache_block_size = self.get_cache_block_size_bytes()
  201. if cache_block_size == 0:
  202. num_gpu_blocks = 0
  203. num_cpu_blocks = 0
  204. else:
  205. num_gpu_blocks = int(
  206. (total_gpu_memory * self.cache_config.gpu_memory_utilization -
  207. peak_memory) // cache_block_size)
  208. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  209. cache_block_size)
  210. num_gpu_blocks = max(num_gpu_blocks, 0)
  211. num_cpu_blocks = max(num_cpu_blocks, 0)
  212. if self.model_runner.lora_manager:
  213. self.model_runner.remove_all_loras()
  214. gc.collect()
  215. torch.cuda.empty_cache()
  216. return num_gpu_blocks, num_cpu_blocks
  217. def initialize_cache(self, num_gpu_blocks: int,
  218. num_cpu_blocks: int) -> None:
  219. """Allocate GPU and CPU KV cache with the specified number of blocks.
  220. This also warms up the model, which may record CUDA graphs.
  221. """
  222. raise_if_cache_size_invalid(num_gpu_blocks,
  223. self.cache_config.block_size,
  224. self.cache_config.is_attention_free,
  225. self.model_config.max_model_len)
  226. self.cache_config.num_gpu_blocks = num_gpu_blocks
  227. self.cache_config.num_cpu_blocks = num_cpu_blocks
  228. self._init_cache_engine()
  229. self._warm_up_model()
  230. def _init_cache_engine(self):
  231. assert self.cache_config.num_gpu_blocks is not None
  232. self.cache_engine = [
  233. CacheEngine(self.cache_config, self.model_config,
  234. self.parallel_config, self.device_config, self.rank)
  235. for _ in range(self.parallel_config.pipeline_parallel_size)
  236. ]
  237. self.gpu_cache = [
  238. self.cache_engine[ve].gpu_cache
  239. for ve in range(self.parallel_config.pipeline_parallel_size)
  240. ]
  241. def _warm_up_model(self) -> None:
  242. if not self.model_config.enforce_eager:
  243. self.model_runner.capture_model(self.gpu_cache)
  244. # Reset the seed to ensure that the random state is not affected by
  245. # the model initialization and profiling.
  246. set_random_seed(self.model_config.seed)
  247. @property
  248. def do_metadata_broadcast(self) -> bool:
  249. return self.parallel_config.tensor_parallel_size > 1
  250. @property
  251. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  252. return self.gpu_cache
  253. @torch.inference_mode()
  254. def prepare_worker_input(
  255. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  256. virtual_engine = execute_model_req.virtual_engine
  257. num_steps = execute_model_req.num_steps
  258. num_seq_groups = len(execute_model_req.seq_group_metadata_list)
  259. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
  260. # they contain parameters to launch cudamemcpyasync.
  261. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
  262. device="cpu",
  263. dtype=torch.int64).view(-1, 2)
  264. blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
  265. device="cpu",
  266. dtype=torch.int64).view(-1, 2)
  267. # `blocks_to_copy` is a gpu tensor. The src and tgt of
  268. # blocks to copy are in the same device, and `blocks_to_copy`
  269. # can be used directly within cuda kernels.
  270. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
  271. device=self.device,
  272. dtype=torch.int64).view(-1, 2)
  273. return WorkerInput(num_seq_groups=num_seq_groups,
  274. blocks_to_swap_in=blocks_to_swap_in,
  275. blocks_to_swap_out=blocks_to_swap_out,
  276. blocks_to_copy=blocks_to_copy,
  277. virtual_engine=virtual_engine,
  278. num_steps=num_steps)
  279. @torch.inference_mode()
  280. def execute_worker(self, worker_input: WorkerInput) -> None:
  281. virtual_engine = worker_input.virtual_engine
  282. # Issue cache operations.
  283. if (worker_input.blocks_to_swap_in is not None
  284. and worker_input.blocks_to_swap_in.numel() > 0):
  285. self.cache_engine[virtual_engine].swap_in(
  286. worker_input.blocks_to_swap_in)
  287. if (worker_input.blocks_to_swap_out is not None
  288. and worker_input.blocks_to_swap_out.numel() > 0):
  289. self.cache_engine[virtual_engine].swap_out(
  290. worker_input.blocks_to_swap_out)
  291. if (worker_input.blocks_to_copy is not None
  292. and worker_input.blocks_to_copy.numel() > 0):
  293. self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
  294. def add_lora(self, lora_request: LoRARequest) -> bool:
  295. return self.model_runner.add_lora(lora_request)
  296. def remove_lora(self, lora_id: int) -> bool:
  297. return self.model_runner.remove_lora(lora_id)
  298. def pin_lora(self, lora_id: int) -> bool:
  299. return self.model_runner.pin_lora(lora_id)
  300. def list_loras(self) -> Set[int]:
  301. return self.model_runner.list_loras()
  302. def add_prompt_adapter(
  303. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  304. return self.model_runner.add_prompt_adapter(prompt_adapter_request)
  305. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  306. return self.model_runner.remove_lora(prompt_adapter_id)
  307. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  308. return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
  309. def list_prompt_adapters(self) -> Set[int]:
  310. return self.model_runner.list_prompt_adapters()
  311. @property
  312. def max_model_len(self) -> int:
  313. return self.model_config.max_model_len
  314. @property
  315. def vocab_size(self) -> int:
  316. return self.model_runner.vocab_size
  317. def get_cache_block_size_bytes(self) -> int:
  318. """Get the size of the KV cache block size in bytes.
  319. """
  320. return CacheEngine.get_cache_block_size(self.cache_config,
  321. self.model_config,
  322. self.parallel_config)
  323. def init_worker_distributed_environment(
  324. parallel_config: ParallelConfig,
  325. rank: int,
  326. distributed_init_method: Optional[str] = None,
  327. local_rank: int = -1,
  328. ) -> None:
  329. """Initialize the distributed environment."""
  330. set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
  331. init_distributed_environment(parallel_config.world_size, rank,
  332. distributed_init_method, local_rank)
  333. ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
  334. parallel_config.pipeline_parallel_size)
  335. def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  336. # Check if the GPU supports the dtype.
  337. if torch_dtype == torch.bfloat16:
  338. compute_capability = current_platform.get_device_capability()
  339. if compute_capability[0] < 8:
  340. gpu_name = current_platform.get_device_name()
  341. raise ValueError(
  342. "Bfloat16 is only supported on GPUs with compute capability "
  343. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  344. f"{compute_capability[0]}.{compute_capability[1]}. "
  345. "You can use float16 instead by explicitly setting the"
  346. "`dtype` flag in CLI, for example: --dtype=half.")
  347. def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
  348. max_model_len) -> None:
  349. if is_attention_free and num_gpu_blocks != 0:
  350. raise ValueError("No memory should be allocated for the cache blocks "
  351. f"for an attention-free model, but {num_gpu_blocks}"
  352. "blocks are allocated.")
  353. if not is_attention_free and num_gpu_blocks <= 0:
  354. raise ValueError("No available memory for the cache blocks. "
  355. "Try increasing `gpu_memory_utilization` when "
  356. "initializing the engine.")
  357. max_seq_len = block_size * num_gpu_blocks
  358. rank = get_tensor_model_parallel_rank()
  359. if rank == 0:
  360. logger.info(f"Maximum sequence length allowed in the cache: "
  361. f"{max_seq_len}")
  362. if not is_attention_free and max_model_len > max_seq_len:
  363. original_max_model_len = max_model_len
  364. max_model_len = max_seq_len
  365. # raise ValueError(
  366. # f"The model's max seq len ({max_model_len}) "
  367. # "is larger than the maximum number of tokens that can be "
  368. # f"stored in KV cache ({max_seq_len}). Try increasing "
  369. # "`gpu_memory_utilization` or decreasing `max_model_len` when "
  370. # "initializing the engine.")
  371. # set the max_model_len to the max_seq_len, but raise a logger.error
  372. # so the user is made aware of this
  373. logger.error(
  374. f"The model's max seq len ({original_max_model_len}) "
  375. "is larger than the maximum number of tokens that can be "
  376. f"stored in KV cache ({max_seq_len}). "
  377. "Try increasing "
  378. "`gpu_memory_utilization`, setting "
  379. "`--enable-chunked-prefill`, or `--kv-cache-dtype fp8` "
  380. "when initializing the engine. The last two are currently "
  381. "mutually exclusive.\n"
  382. f"Forcing max_model_len to {max_seq_len}.")