import time
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from aphrodite.common.config import (ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.logger import init_logger
from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import (SamplerOutput, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.modeling.sampling_metadata import PersistentMetadata
from aphrodite.common.utils import in_wsl

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
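# Illustrative note: with the list above, a decode batch of 3 is rounded up to
# the graph captured for batch size 4, and a batch of 13 is rounded up to 16
# (see _get_graph_batch_size at the bottom of this file).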

class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME: This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)

        self.model = None
        self.block_size = None  # Set after initial profiling.

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        # Cache the in_wsl result.
        self.in_wsl = in_wsl()
    def load_model(self) -> None:
        self.model = get_model(self.model_config)

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

        max_num_blocks = (self.max_context_len_to_capture + block_size -
                          1) // block_size
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
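    # Worked example (illustrative values, not taken from any config): with
    # max_context_len_to_capture = 8192 and block_size = 16, max_num_blocks is
    # 512, so graph_block_tables has shape (256, 512): one padded block-table
    # row for every sequence in the largest captured batch.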
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []

        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(list(range(prompt_len)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(prompt_lens)
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_prompt_len,
                                             pad=0,
                                             dtype=torch.long)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_prompt_len,
                                                pad=0,
                                                dtype=torch.long)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long)

        input_metadata = InputMetadata(
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping,
            max_context_len=None,
            context_lens=None,
            block_tables=None,
            use_cuda_graph=False,
        )
        return input_tokens, input_positions, input_metadata
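    # Note on shapes: for prompts, input_tokens and input_positions come back
    # as (num_seq_groups, max_prompt_len) long tensors padded to the longest
    # prompt in the batch, and slot_mapping holds _PAD_SLOT_ID for padded and
    # sliding-window-masked positions.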
    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            # Pad the input tokens, positions, and slot mapping to match the
            # batch size of the captured graph.
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append([])
                input_positions.append([])
                slot_mapping.append([])
                context_lens.append(1)
                block_tables.append([])
            batch_size = graph_batch_size

        # When using CUDA graph, we don't need to make the tensors on the GPU
        # because they will eventually be copied to the designated GPU buffer.
        device = "cpu" if use_captured_graph else "cuda"
        pin_memory = use_captured_graph and not self.in_wsl
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_len=1,
                                             pad=0,
                                             dtype=torch.long,
                                             device=device,
                                             pin_memory=pin_memory)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_len=1,
                                                pad=0,
                                                dtype=torch.long,
                                                device=device,
                                                pin_memory=pin_memory)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_len=1,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device=device,
                                             pin_memory=pin_memory)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=device,
                                    pin_memory=pin_memory)

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)
        else:
            block_tables = _make_tensor_with_pad(
                block_tables,
                max_len=max_context_len,
                pad=0,
                dtype=torch.int,
            )

        input_metadata = InputMetadata(
            prompt_lens=[],
            slot_mapping=slot_mapping,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
        return input_tokens, input_positions, input_metadata
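    # Note on shapes: during decode each sequence contributes exactly one new
    # token, so the tensors above are (batch_size, 1). When a captured graph is
    # used, the batch is padded with empty rows (context_len = 1) up to the
    # nearest captured batch size and kept on the CPU, because CUDAGraphRunner
    # later copies it into the graph's fixed device buffers.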
    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        max_prompt_len = max(prompt_lens) if prompt_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                prompt_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need to be sampled,
                    # so skip them when bucketing by sampling type.
                    categorized_sample_indices_start_idx += prompt_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        categorized_sample_indices_start_idx)
                categorized_sample_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += max_prompt_len
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        range(categorized_sample_indices_start_idx,
                              categorized_sample_indices_start_idx + num_seqs))
                categorized_sample_indices_start_idx += num_seqs

        selected_token_indices = _async_h2d(selected_token_indices,
                                            dtype=torch.long,
                                            pin_memory=not self.in_wsl)
        categorized_sample_indices = {
            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)
        seq_persistence_data: Dict[int, dict] = {}
        for grp in seq_group_metadata_list:
            seq_persistence_data.update(grp.persistent_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            persistent_metadata=PersistentMetadata(seq_persistence_data),
        )
        return sampling_metadata
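    # Worked example (illustrative): for two prompts of lengths 5 and 3 with
    # prompt_logprobs disabled, rows are padded to max_prompt_len = 5, so
    # selected_token_indices = [4, 7] picks the last real token of each prompt
    # out of the flattened hidden states.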
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> SamplerOutput:
        # NOTE: We assume that the sequences in the group are either all
        # prompts or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs

        # Execute the model.
        if input_metadata.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )

        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 input_metadata.prompt_lens)

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return output
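    # Assumed call order (driven by the worker, not enforced here):
    # load_model() -> profile_run() -> set_block_size() -> capture_model() ->
    # execute_model() once per scheduler step. Decode batches whose size and
    # context length fit a captured graph are replayed through CUDAGraphRunner;
    # everything else runs the model eagerly.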
    @torch.inference_mode()
    def profile_run(self) -> None:  # pylint: disable=useless-return
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                persistent_data={},
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
                       "per GPU. If you are running out of memory, consider "
                       "decreasing `gpu_memory_utilization` or enforcing "
                       "eager mode.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, 1,
                                      dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
            # Create dummy input_metadata.
            input_metadata = InputMetadata(
                prompt_lens=[],
                slot_mapping=slot_mapping[:batch_size],
                max_context_len=self.max_context_len_to_capture,
                context_lens=context_lens[:batch_size],
                block_tables=block_tables[:batch_size],
                use_cuda_graph=True,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens[:batch_size],
                input_positions[:batch_size],
                kv_caches,
                input_metadata,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

class CUDAGraphRunner:

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(  # pylint: disable=useless-return
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        self.model(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
        )
        torch.cuda.synchronize()

        # Capture the graph.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states = self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
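    # Usage note (illustrative): replay requires inputs with exactly the shapes
    # used at capture time; forward() copies them into the fixed input buffers,
    # replays the graph, and returns the graph's own output buffer, which is
    # overwritten by the next replay.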

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
    pin_memory: bool = False,
) -> torch.Tensor:
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x,
                        dtype=dtype,
                        device=device,
                        pin_memory=pin_memory and str(device) == "cpu")


def _get_graph_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        return (batch_size + 7) // 8 * 8
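# For example: _get_graph_batch_size(1) == 1, _get_graph_batch_size(3) == 4,
# _get_graph_batch_size(5) == 8, and _get_graph_batch_size(13) == 16, matching
# entries of _BATCH_SIZES_TO_CAPTURE.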

def _async_h2d(data: list, dtype, pin_memory):
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
    return t.to(device="cuda", non_blocking=True)
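# Note: the non-blocking copy above only overlaps with other work when the
# source tensor is in pinned host memory; callers pass pin_memory=False on WSL
# (where pinned memory is poorly supported), and the copy then behaves
# synchronously.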