model_runner.py

import time
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from aphrodite.common.config import (ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.logger import init_logger
from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import (SamplerOutput, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.modeling.sampling_metadata import PersistentMetadata
from aphrodite.common.utils import in_wsl

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]

_PAD_SLOT_ID = -1
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME: This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)

        self.model = None
        self.block_size = None  # Set after initial profiling.

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.

        # cache in_wsl result
        self.in_wsl = in_wsl()

    def load_model(self) -> None:
        self.model = get_model(self.model_config)

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

        max_num_blocks = (self.max_context_len_to_capture + block_size -
                          1) // block_size
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
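
    # Illustrative sizing (the values below are assumptions, not defaults):
    # with block_size=16 and max_context_len_to_capture=8192, max_num_blocks
    # is 8192 // 16 = 512, so graph_block_tables has shape (256, 512), where
    # 256 is the largest batch size in _BATCH_SIZES_TO_CAPTURE.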

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []

        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(list(range(prompt_len)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(prompt_lens)
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_prompt_len,
                                             pad=0,
                                             dtype=torch.long)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_prompt_len,
                                                pad=0,
                                                dtype=torch.long)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long)

        input_metadata = InputMetadata(
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping,
            max_context_len=None,
            context_lens=None,
            block_tables=None,
            use_cuda_graph=False,
        )
        return input_tokens, input_positions, input_metadata
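
    # The slot arithmetic above, worked through on assumed numbers: with
    # block_size=16, the token at position 37 lives in physical block
    # block_table[37 // 16] == block_table[2] at offset 37 % 16 == 5, so its
    # slot index is block_table[2] * 16 + 5.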

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            # Pad the input tokens, positions, and slot mapping to match the
            # batch size of the captured graph.
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append([])
                input_positions.append([])
                slot_mapping.append([])
                context_lens.append(1)
                block_tables.append([])
            batch_size = graph_batch_size

        # When using CUDA graph, we don't need to create the tensors on the
        # GPU because they will eventually be copied to the designated GPU
        # buffer.
        device = "cpu" if use_captured_graph else "cuda"
        pin_memory = use_captured_graph and not self.in_wsl
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_len=1,
                                             pad=0,
                                             dtype=torch.long,
                                             device=device,
                                             pin_memory=pin_memory)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_len=1,
                                                pad=0,
                                                dtype=torch.long,
                                                device=device,
                                                pin_memory=pin_memory)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_len=1,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device=device,
                                             pin_memory=pin_memory)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=device,
                                    pin_memory=pin_memory)

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)
        else:
            block_tables = _make_tensor_with_pad(
                block_tables,
                max_len=max_context_len,
                pad=0,
                dtype=torch.int,
            )

        input_metadata = InputMetadata(
            prompt_lens=[],
            slot_mapping=slot_mapping,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
        return input_tokens, input_positions, input_metadata
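
    # Illustrative padding (assumed batch size): a decode batch of 13
    # sequences is padded with empty token/position/slot rows, context_lens
    # of 1, and empty block tables up to the captured batch size
    # _get_graph_batch_size(13) == 16, and execute_model later replays
    # self.graph_runners[16].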

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        max_prompt_len = max(prompt_lens) if prompt_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                prompt_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: Prompt token positions do not need to be sampled,
                    # so skip them when computing the sample indices.
                    categorized_sample_indices_start_idx += prompt_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        categorized_sample_indices_start_idx)
                categorized_sample_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += max_prompt_len
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        range(categorized_sample_indices_start_idx,
                              categorized_sample_indices_start_idx + num_seqs))
                categorized_sample_indices_start_idx += num_seqs

        selected_token_indices = _async_h2d(selected_token_indices,
                                            dtype=torch.long,
                                            pin_memory=not self.in_wsl)
        categorized_sample_indices = {
            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        seq_persistence_data: Dict[int, dict] = {}
        for grp in seq_group_metadata_list:
            seq_persistence_data.update(grp.persistent_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            persistent_metadata=PersistentMetadata(seq_persistence_data),
        )
        return sampling_metadata
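
    # Illustrative indexing (assumed batch, not taken from a real run): for
    # two prompts of lengths 3 and 5 with prompt_logprobs disabled,
    # max_prompt_len is 5 and the hidden states are laid out per sequence in
    # strides of 5, so selected_token_indices becomes [2, 9] (the last token
    # of each prompt).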

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> SamplerOutput:
        # NOTE: We assume that the sequence groups in the batch are either
        # all prompts or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs

        # Execute the model.
        if input_metadata.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )

        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 input_metadata.prompt_lens)

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return output
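
    # Dispatch note: in the decode path, _prepare_decode has already padded
    # the batch to a captured size, so when use_cuda_graph is set,
    # input_tokens.shape[0] is always a key of self.graph_runners.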

    @torch.inference_mode()
    def profile_run(self) -> None:  # pylint: disable=useless-return
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                persistent_data={},
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return
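
    # Illustrative token split (assumed limits, not defaults): with
    # max_num_batched_tokens=2048 and max_num_seqs=8, the loop above builds
    # 8 dummy prompts of 256 tokens each, so exactly 2048 tokens are run
    # through the model for profiling.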

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.warning("CUDA graphs can take an additional 1-3 GiB of memory "
                       "per GPU. If you are running out of memory, consider "
                       "decreasing `gpu_memory_utilization` or enforcing "
                       "eager mode.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, 1,
                                      dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graphs.
        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
            # Create dummy input_metadata.
            input_metadata = InputMetadata(
                prompt_lens=[],
                slot_mapping=slot_mapping[:batch_size],
                max_context_len=self.max_context_len_to_capture,
                context_lens=context_lens[:batch_size],
                block_tables=block_tables[:batch_size],
                use_cuda_graph=True,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens[:batch_size],
                input_positions[:batch_size],
                kv_caches,
                input_metadata,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")


class CUDAGraphRunner:

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(  # pylint: disable=useless-return
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        self.model(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
        )
        torch.cuda.synchronize()

        # Capture the graph.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states = self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
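
# Minimal usage sketch for CUDAGraphRunner (the model and inputs below are
# placeholders, not part of this module):
#
#     runner = CUDAGraphRunner(model)
#     runner.capture(input_ids, positions, kv_caches, input_metadata,
#                    memory_pool=None)
#     hidden_states = runner(input_ids, positions, kv_caches, input_metadata)
#
# After capture(), each call only copies the new inputs into the saved
# buffers and replays the recorded graph; no new CUDA kernels are launched.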


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
    pin_memory: bool = False,
) -> torch.Tensor:
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x,
                        dtype=dtype,
                        device=device,
                        pin_memory=pin_memory and str(device) == "cpu")
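
# A minimal sketch of the padding helpers above (device chosen for
# illustration only):
#
#     _make_tensor_with_pad([[1, 2], [3]], max_len=3, pad=0,
#                           dtype=torch.long, device="cpu")
#     # -> tensor([[1, 2, 0], [3, 0, 0]])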


def _get_graph_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        return (batch_size + 7) // 8 * 8
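
# How the bucketing above maps onto _BATCH_SIZES_TO_CAPTURE (illustrative):
#
#     _get_graph_batch_size(1)    # -> 1
#     _get_graph_batch_size(3)    # -> 4
#     _get_graph_batch_size(13)   # -> 16
#     _get_graph_batch_size(250)  # -> 256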


def _async_h2d(data: list, dtype, pin_memory):
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
    return t.to(device="cuda", non_blocking=True)