worker.py
  1. """A GPU worker class."""
  2. import os
  3. from typing import Dict, List, Tuple, Optional
  4. import torch
  5. import torch.distributed
  6. from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
  7. SchedulerConfig)
  8. from aphrodite.modeling import get_model, InputMetadata, set_random_seed
  9. from aphrodite.modeling.megatron.parallel_state import (
  10. initialize_model_parallel)
  11. from aphrodite.common.sampling_params import SamplingParams
  12. from aphrodite.common.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
  13. from aphrodite.task_handler.cache_engine import CacheEngine
  14. from aphrodite.common.utils import get_gpu_memory


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: Optional[int] = None,
        distributed_init_method: Optional[str] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self):
        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        # Env vars will be set by Ray.
        self.rank = self.rank if self.rank is not None else int(
            os.getenv("RANK", "-1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{local_rank}")
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        torch.cuda.set_device(self.device)

        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)
        self.model = get_model(self.model_config)

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model.config.vocab_size
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        seqs = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seqs)

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        torch.cuda.empty_cache()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        return num_gpu_blocks, num_cpu_blocks
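
    # NOTE: Illustrative numbers for the block count above: with 80 GiB of
    # total GPU memory, gpu_memory_utilization=0.90, a profiled peak of
    # 20 GiB, and a 2 MiB cache block, the worker reports
    # int((80 * 0.90 - 20) GiB // 2 MiB) = 26624 GPU blocks.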

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.block_size = cache_config.block_size
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def _prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not
                # initialized yet. In this case, we just use a dummy
                # slot mapping.
                slot_mapping.extend([0] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
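
        # Example (illustrative): with block_size=16 and block_table=[3, 7],
        # prompt token i=20 falls in logical block 20 // 16 == 1, which maps
        # to physical block 7; its offset is 20 % 16 == 4, so its slot is
        # 7 * 16 + 4 == 116.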

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                context_len = seq_data.get_len()
                position = context_len - 1
                input_positions.append(position)

                block_table = seq_group_metadata.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
                                             len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
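        # Example (illustrative): a batch of 13 tokens is padded with three
        # zeros to length 16, the next multiple of 8.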

        # Convert to tensors.
        tokens_tensor = torch.cuda.LongTensor(input_tokens)
        positions_tensor = torch.cuda.LongTensor(input_positions)
        slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
        context_lens_tensor = torch.cuda.IntTensor(context_lens)
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables
        ]
        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output
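

# Illustrative call order (a sketch, not part of this module): the engine that
# drives a Worker typically does roughly the following. The config objects and
# the scheduler output (SequenceGroupMetadata) are constructed elsewhere, and
# the init method URL below is only an example value.
#
#   worker = Worker(model_config, parallel_config, scheduler_config,
#                   rank=0, distributed_init_method="tcp://localhost:29500")
#   worker.init_model()
#   num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
#       block_size=16, gpu_memory_utilization=0.90,
#       cpu_swap_space=4 * 1024**3)
#   # The engine records the block counts in its CacheConfig, then:
#   worker.init_cache_engine(cache_config)
#   output = worker.execute_model(seq_group_metadata_list,
#                                 blocks_to_swap_in={},
#                                 blocks_to_swap_out={},
#                                 blocks_to_copy={})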


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    """Pad `x` with zeros so its length becomes a multiple of `multiple_of`."""
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    """Pad `x` with zeros up to length `max_len`."""
    return x + [0] * (max_len - len(x))