worker.py

  1. """A GPU worker class."""
  2. import os
  3. from typing import Dict, List, Tuple, Optional
  4. import torch
  5. import torch.distributed
  6. from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
  7. SchedulerConfig)
  8. from aphrodite.modeling import get_model, InputMetadata, set_random_seed
  9. from aphrodite.modeling.megatron.parallel_state import (
  10. initialize_model_parallel)
  11. from aphrodite.common.sampling_params import SamplingParams
  12. from aphrodite.common.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
  13. from aphrodite.task_handler.cache_engine import CacheEngine
  14. from aphrodite.common.utils import get_gpu_memory, get_max_shared_memory_bytes


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: Optional[int] = None,
        distributed_init_method: Optional[str] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.sliding_window = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self):
        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        # Env vars will be set by Ray.
        self.rank = self.rank if self.rank is not None else int(
            os.getenv("RANK", "-1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{local_rank}")
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)
        self.model = get_model(self.model_config,
                               self.scheduler_config.max_num_batched_tokens)

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model.config.vocab_size
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        seqs = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)
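
        # For example (illustrative values): max_num_batched_tokens=2560 and
        # max_num_seqs=256 produce 256 dummy prompts of 10 tokens each; any
        # remainder tokens are spread over the first few groups.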

        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seqs)

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
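
        # Illustrative arithmetic (not real defaults): with 80 GiB of GPU
        # memory, gpu_memory_utilization=0.9, a profiled peak of 20 GiB, and
        # a 2 MiB cache block, (80 * 0.9 - 20) GiB / 2 MiB = 26,624 GPU
        # blocks. The actual block size comes from
        # CacheEngine.get_cache_block_size().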
        torch.cuda.empty_cache()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.block_size = cache_config.block_size
        self.sliding_window = cache_config.sliding_window

        if self.sliding_window is None:
            max_seq_len = self.scheduler_config.max_model_len
        else:
            max_seq_len = min(self.scheduler_config.max_model_len,
                              self.sliding_window)
        _check_if_can_support_max_seq_len(max_seq_len, self.block_size)

        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def _prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]
            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(list(range(prompt_len)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not
                # initialized yet. In this case, we just use a dummy
                # slot mapping.
                slot_mapping.append([0] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)
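
                # E.g. with block_size=16, prompt token i=35 lives in logical
                # block 35 // 16 == 2 at offset 35 % 16 == 3, so its slot is
                # block_table[2] * 16 + 3.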

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                context_len = seq_data.get_len()
                position = context_len - 1
                if self.sliding_window is not None:
                    context_len = min(context_len, self.sliding_window)
                input_positions.append([position])

                block_table = seq_group_metadata.block_tables[seq_id]

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
                                             len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

                if self.sliding_window:
                    assert self.cache_config is not None
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                generation_block_tables.append(block_table)
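
                # E.g. a sliding window of 1024 tokens with block_size=16
                # keeps only the last 1024 // 16 == 64 blocks of the table.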

        # NOTE: This part was optimized!
        max_seq_len = max(prompt_lens) if prompt_lens else 1
        padded_input_tokens = [
            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
        ]
        padded_input_positions = [
            _pad_to_max(positions, max_seq_len, pad=0)
            for positions in input_positions
        ]
        padded_slot_mapping = [
            _pad_to_max(mapping, max_seq_len, pad=-1)
            for mapping in slot_mapping
        ]
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
            for block_table in generation_block_tables
        ]
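
        # E.g. with prompts of length 7 and 5, max_seq_len is 7 and the
        # shorter prompt is padded with two trailing 0 tokens/positions and
        # two -1 slot-mapping entries.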

        # Convert to tensors.
        tokens_tensor = torch.tensor(padded_input_tokens,
                                     dtype=torch.long,
                                     device="cuda")
        positions_tensor = torch.tensor(padded_input_positions,
                                        dtype=torch.long,
                                        device="cuda")
        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
                                           dtype=torch.int,
                                           device="cuda")
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device="cuda")
        block_tables_tensor = torch.tensor(padded_block_tables,
                                           dtype=torch.int,
                                           device="cuda")

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
            sliding_window=self.sliding_window,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch "
                "world size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )
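
        # distributed_init_method is a standard torch.distributed init URL,
        # e.g. "tcp://127.0.0.1:29500" (address and port are illustrative).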

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)


# TODO: Check if this needs to be removed.
def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
    return x + [pad] * ((-len(x)) % multiple_of)
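    # E.g. _pad_to_alignment([1, 2, 3], multiple_of=8, pad=0) pads up to the
    # next multiple of 8: [1, 2, 3, 0, 0, 0, 0, 0].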


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    return x + [pad] * (max_len - len(x))
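    # E.g. _pad_to_max([1, 2, 3], max_len=5, pad=0) -> [1, 2, 3, 0, 0].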


def _check_if_can_support_max_seq_len(max_seq_len: int,
                                      block_size: int) -> None:
    # Follows the logic in
    # attention_kernels.cu::single_query_cached_kv_attention_launcher
    max_shared_mem = get_max_shared_memory_bytes()
    float32_bytes = torch.finfo(torch.float).bits // 8
    # Round max_seq_len up to a whole number of blocks.
    padded_max_seq_len = (
        (max_seq_len + block_size - 1) // block_size) * block_size
    # padded_max_seq_len + extra buffer
    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
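    # E.g. max_seq_len=4096 with block_size=16 is already block-aligned, so
    # the check requires roughly (4096 + 512) * 4 = 18,432 bytes of shared
    # memory.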
    if padded_max_seq_len * float32_bytes > max_shared_mem:
        raise RuntimeError(
            f"Aphrodite cannot currently support max_model_len={max_seq_len} "
            f"with block_size={block_size} on GPU with compute "
            f"capability {torch.cuda.get_device_capability()} "
            f"(required shared memory {required_shared_mem} > "
            f"available shared memory {max_shared_mem}). "
            "This will be fixed in a future release.")


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # bfloat16 requires compute capability 8.0 (Ampere) or newer.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}. Please "
                "use the `--dtype float16` argument when launching the engine."
            )