12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011 |
- import contextlib
- import time
- from typing import Dict, List, Optional, Tuple, Set, Union
- import numpy as np
- import torch
- import torch.nn as nn
- from loguru import logger
- from aphrodite.common.config import (
- DeviceConfig,
- ModelConfig,
- LoRAConfig,
- ParallelConfig,
- SchedulerConfig,
- )
- from aphrodite.common.logger import get_loading_progress_bar
- from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
- from aphrodite.modeling.megatron import cupy_utils
- from aphrodite.modeling.megatron.communication_op import broadcast_tensor_dict
- from aphrodite.modeling.megatron.parallel_state import (
- get_tensor_model_parallel_world_size,
- with_cupy_nccl_for_all_reduce,
- )
- from aphrodite.modeling.megatron import custom_all_reduce
- from aphrodite.common.sampling_params import SamplingParams, SamplingType
- from aphrodite.common.sequence import (
- SamplerOutput,
- SequenceData,
- SequenceGroupMetadata,
- )
- from aphrodite.modeling.sampling_metadata import PersistentMetadata
- from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
- from aphrodite.lora.layers import LoRAMapping
- from aphrodite.lora.request import LoRARequest
- from aphrodite.common.utils import in_wsl, measure_cuda_memory
# A pair of cache tensors for one attention layer (see profile_run, which
# builds one (None, None) entry per layer).
KVCache = Tuple[torch.Tensor, torch.Tensor]
# Slot id used for tokens that must not be written to the KV cache: batch
# padding, tokens outside the sliding window, and the profiling run.
_PAD_SLOT_ID = -1
# Rank of the dummy LoRA adapters created during memory profiling.
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, *range(8, 257, 8)]
class ModelRunner:
    """Turns scheduler output into model inputs and runs one forward step.

    On the driver worker (``is_driver_worker=True``) ``prepare_input_tensors``
    builds the input tensors from ``SequenceGroupMetadata`` and broadcasts
    them with ``broadcast_tensor_dict``; the other workers rebuild the same
    inputs from that broadcast. Decode steps may be replayed through CUDA
    graphs recorded by ``capture_model``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        kv_quant_params_path: Optional[str] = None,
        is_driver_worker: bool = False,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME: This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.model = None
        self.block_size = None  # Set after initial profiling.
        self.lora_manager = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.

        # cache in_wsl result
        self.in_wsl = in_wsl()
        self.kv_cache_dtype = kv_cache_dtype
        # Per-layer quantization parameters are only loaded for an int8
        # KV cache.
        self.kv_quant_params = (
            self.load_kv_quant_params(model_config, kv_quant_params_path)
            if self.kv_cache_dtype == "int8" else None)

    def load_kv_quant_params(
            self, model_config: ModelConfig,
            kv_quant_params_path: Optional[str]
    ) -> Optional[List[List[float]]]:
        """Load the per-layer int8 KV-cache quantization parameters.

        Layer ``i`` is read as raw float32 values from
        ``kv_quant_params_path + "/layers.{i}.past_kv_scale.0.weight"``.

        Args:
            model_config: Model configuration; may be None in tests.
            kv_quant_params_path: Directory holding the parameter files.

        Returns:
            One list of values per hidden layer. Returns an empty list when
            ``kv_quant_params_path`` is None, and None when ``model_config``
            is None.

        Raises:
            ValueError: If any model architecture is not a supported Llama
                architecture.
        """
        if model_config is None:
            return None
        # Remove it when all models support kv cache int8.
        architectures = model_config.hf_config.architectures
        for arch in architectures:
            if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
                raise ValueError(
                    "KV CACHE INT8 is not supported for model architectures "
                    f"{arch} for now. "
                    "Supported architectures: LlamaForCausalLM and "
                    "LLaMAForCausalLM.")
        num_layers = model_config.hf_config.num_hidden_layers
        kv_quant_params: List[List[float]] = []
        if kv_quant_params_path is None:
            # NOTE(review): int8 KV cache without a params path yields no
            # per-layer parameters -- confirm this is intended.
            return kv_quant_params
        for i in range(num_layers):
            path = (kv_quant_params_path +
                    f"/layers.{i}.past_kv_scale.0.weight")
            kv_quant_params.append(list(np.fromfile(path, dtype=np.float32)))
        return kv_quant_params

    def load_model(self) -> None:
        """Load the model weights and log their memory footprint.

        When LoRA is enabled, this also checks that the model exposes the
        attributes LoRA needs and wraps it with an LRU LoRA manager.
        """
        with measure_cuda_memory() as m:
            self.model = get_model(self.model_config, self.device_config,
                                   self.lora_config)
        self.model_memory_usage = m.consumed_memory
        tp = get_tensor_model_parallel_world_size()
        logger.info(
            "Model weights loaded. Memory usage: "
            f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
            f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")

        vocab_size = self.model.config.vocab_size

        if self.lora_config:
            assert (hasattr(self.model, "supported_lora_modules")
                    and self.model.supported_lora_modules
                    ), "Model does not support LoRA"
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens +
                self.scheduler_config.max_paddings,
                vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

    def set_block_size(self, block_size: int) -> None:
        """Set the KV-cache block size and allocate the graph block table.

        The numpy table has shape (max captured batch size,
        ceil(max_context_len_to_capture / block_size)) and is reused by
        ``_prepare_decode`` when CUDA graphs are used.
        """
        self.block_size = block_size

        max_num_blocks = (self.max_context_len_to_capture + block_size -
                          1) // block_size
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
               List[int], List[List[int]], List[int], Set[LoRARequest]]:
        """Build padded input tensors for a batch of prompt (prefill) groups.

        Each group must contain exactly one sequence. When the group carries
        ``computed_block_nums`` and no sliding window is configured, the
        already-computed prefix is skipped and only the remaining tokens are
        fed to the model.

        Returns:
            ``(input_tokens, input_positions, input_metadata, prompt_lens,
            subquery_lens, lora_index_mapping, lora_prompt_mapping,
            lora_requests)``. Token, position and slot tensors are padded to
            the longest subquery length in the batch.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        lora_index_mapping: List[List[int]] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)
            computed_len = 0

            # NOTE: This only works for oooooooxxx style attention.
            computed_block_nums = seq_group_metadata.computed_block_nums
            if (computed_block_nums is not None
                    and len(computed_block_nums) > 0
                    and self.sliding_window is None):
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            else:
                prefix_block_tables.append([])
            # actual prompt lens
            context_lens.append(computed_len)
            subquery_lens.append(prompt_len - computed_len)

            input_tokens.append(prompt_tokens)
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(
                list(range(computed_len, computed_len + len(prompt_tokens))))

            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
            # One entry per token handed to the sampler. This must use the
            # same `is not None` test as _prepare_sample; a truthiness test
            # makes prompt_logprobs=0 produce a mapping shorter than the
            # selected token indices.
            lora_prompt_mapping.extend(
                [lora_id] *
                (prompt_len - computed_len
                 if seq_group_metadata.sampling_params.prompt_logprobs
                 is not None else 1))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(computed_len, prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(subquery_lens)
        input_tokens = _make_tensor_with_pad(
            input_tokens,
            max_prompt_len,
            pad=0,
            dtype=torch.long,
            device=self.device,
        )
        input_positions = _make_tensor_with_pad(
            input_positions,
            max_prompt_len,
            pad=0,
            dtype=torch.long,
            device=self.device,
        )
        slot_mapping = _make_tensor_with_pad(
            slot_mapping,
            max_prompt_len,
            pad=_PAD_SLOT_ID,
            dtype=torch.long,
            device=self.device,
        )
        lora_index_mapping = [
            _pad_to_max(mapping, max_prompt_len, pad=0)
            for mapping in lora_index_mapping
        ]
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)
        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = _make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )
        # Every prompt occupies max_prompt_len rows of the padded batch.
        start_loc_tensor = torch.arange(
            0,
            len(prompt_lens) * max_prompt_len,
            max_prompt_len,
            dtype=torch.long,
            device=self.device,
        )
        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)

        input_metadata = InputMetadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens_tensor,
            max_seq_len=max_prompt_len,
            start_loc=start_loc_tensor,
            max_context_len=None,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
            kv_quant_params=self.kv_quant_params,
        )
        return (
            input_tokens,
            input_positions,
            input_metadata,
            prompt_lens,
            subquery_lens,
            lora_index_mapping,
            lora_prompt_mapping,
            lora_requests,
        )

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[List[int]],
               List[int], Set[LoRARequest]]:
        """Build input tensors for a batch of decode groups.

        Each sequence contributes exactly one token (its last token). When
        CUDA graphs can be used, the batch is padded up to
        ``_get_graph_batch_size(batch_size)`` with dummy rows.

        Returns:
            ``(input_tokens, input_positions, input_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[List[int]] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                context_len = (seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window))
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])
                lora_index_mapping.append([lora_id])
                lora_prompt_mapping.append(lora_id)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        # capture_model records a graph for every padded size up to
        # _get_graph_batch_size(max_num_seqs), so the lookup in execute_model
        # succeeds for any batch the scheduler can produce.
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            # Pad the input tokens, positions, and slot mapping to match the
            # batch size of the captured graph.
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append([])
                input_positions.append([])
                slot_mapping.append([])
                context_lens.append(1)
                block_tables.append([])
            batch_size = graph_batch_size

        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_len=1,
                                             pad=0,
                                             dtype=torch.long,
                                             device=self.device)
        input_positions = _make_tensor_with_pad(
            input_positions,
            max_len=1,
            pad=0,
            dtype=torch.long,
            device=self.device,
        )
        slot_mapping = _make_tensor_with_pad(
            slot_mapping,
            max_len=1,
            pad=_PAD_SLOT_ID,
            dtype=torch.long,
            device=self.device,
        )
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = _make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        lora_index_mapping = [
            _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
        ]

        input_metadata = InputMetadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            max_seq_len=None,
            start_loc=None,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
            kv_quant_params=self.kv_quant_params,
        )
        return (
            input_tokens,
            input_positions,
            input_metadata,
            lora_index_mapping,
            lora_prompt_mapping,
            lora_requests,
        )

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the sampler metadata for the batch.

        ``selected_token_indices`` index rows of the flattened hidden states.
        Each prompt group spans ``max(subquery_lens)`` rows, and each decode
        sequence spans one row. A prompt group selects its last token, plus
        every earlier token when ``prompt_logprobs`` is requested.
        ``categorized_sample_indices`` groups the sampled rows by
        ``SamplingType``.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        max_subquery_len = max(subquery_lens) if subquery_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need sample, skip
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        categorized_sample_indices_start_idx)
                categorized_sample_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(
                            selected_token_start_idx,
                            selected_token_start_idx + subquery_len - 1,
                        ))
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += max_subquery_len
                if sampling_params.seed is not None:
                    # Create the generator on the runner's device instead of a
                    # hard-coded "cuda", matching every other tensor built here.
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(
                        selected_token_start_idx,
                        selected_token_start_idx + num_seqs,
                    ))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        range(
                            categorized_sample_indices_start_idx,
                            categorized_sample_indices_start_idx + num_seqs,
                        ))
                categorized_sample_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices = _async_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=self.device,
            pin_memory=not self.in_wsl,
        )
        categorized_sample_indices = {
            t: _async_h2d(
                indices,
                dtype=torch.int,
                target_device=self.device,
                pin_memory=not self.in_wsl,
            )
            for t, indices in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        seq_persistence_data: Dict[int, dict] = {}
        for grp in seq_group_metadata_list:
            seq_persistence_data.update(grp.persistent_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
            persistent_metadata=PersistentMetadata(seq_persistence_data),
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
               Set[LoRARequest], LoRAMapping]:
        """Build (driver) or receive (other workers) this step's inputs.

        On the driver worker the inputs are built from
        ``seq_group_metadata_list`` and broadcast from rank 0. On other
        workers ``seq_group_metadata_list`` is ignored, the inputs are taken
        from the broadcast, and the returned ``SamplingMetadata`` is created
        with ``perform_sampling=False``.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (
                    input_tokens,
                    input_positions,
                    input_metadata,
                    prompt_lens,
                    subquery_lens,
                    lora_index_mapping,
                    lora_prompt_mapping,
                    lora_requests,
                ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (
                    input_tokens,
                    input_positions,
                    input_metadata,
                    lora_index_mapping,
                    lora_prompt_mapping,
                    lora_requests,
                ) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                flat_lora_index_mapping = [
                    item for sublist in lora_index_mapping for item in sublist
                ]
                lora_mapping = LoRAMapping(
                    flat_lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "is_prompt": input_metadata.is_prompt,
                "slot_mapping": input_metadata.slot_mapping,
                "prompt_lens": input_metadata.prompt_lens,
                "max_seq_len": input_metadata.max_seq_len,
                "start_loc": input_metadata.start_loc,
                "max_context_len": input_metadata.max_context_len,
                "context_lens": input_metadata.context_lens,
                "block_tables": input_metadata.block_tables,
                "use_cuda_graph": input_metadata.use_cuda_graph,
                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                "kv_quant_params": input_metadata.kv_quant_params,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,  # noqa
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
            }
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict["input_tokens"]
            input_positions = metadata_dict["input_positions"]
            lora_mapping = metadata_dict["lora_mapping"]
            lora_requests = metadata_dict["lora_requests"]
            input_metadata = InputMetadata(
                is_prompt=metadata_dict["is_prompt"],
                slot_mapping=metadata_dict["slot_mapping"],
                prompt_lens=metadata_dict["prompt_lens"],
                max_seq_len=metadata_dict["max_seq_len"],
                start_loc=metadata_dict["start_loc"],
                max_context_len=metadata_dict["max_context_len"],
                context_lens=metadata_dict["context_lens"],
                block_tables=metadata_dict["block_tables"],
                use_cuda_graph=metadata_dict["use_cuda_graph"],
                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
                kv_quant_params=metadata_dict["kv_quant_params"],
            )
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=metadata_dict["selected_token_indices"],
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (
            input_tokens,
            input_positions,
            input_metadata,
            sampling_metadata,
            lora_requests,
            lora_mapping,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample the next tokens.

        A captured CUDA graph is replayed when the prepared metadata asks for
        one; otherwise the eager model is called.
        """
        (
            input_tokens,
            input_positions,
            input_metadata,
            sampling_metadata,
            lora_requests,
            lora_mapping,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Execute the model.
        if input_metadata.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one worst-case prefill step to profile memory usage.

        The step uses ``max_num_seqs`` dummy prompts totalling
        ``max_num_batched_tokens`` tokens. With LoRA enabled, it also loads
        ``max_loras`` dummy adapters of rank ``LORA_WARMUP_RANK``.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # This represents the maximum number of different requests
        # that will have unique loras, an therefore the max amount of memory
        # consumption create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Spread the token budget as evenly as possible over the groups.
            seq_len = max_num_batched_tokens // max_num_seqs + (
                group_id < max_num_batched_tokens % max_num_seqs)
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                persistent_data={},
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return

    def remove_all_loras(self) -> bool:
        """Remove every loaded LoRA adapter; raises if LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_all_loras()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` with ``lora_mapping``; raises if LoRA
        is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load the adapter in ``lora_request``; raises if LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the adapter with ``lora_id``; raises if LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the ids of loaded adapters; raises if LoRA is disabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
        """Capture one decode CUDA graph per padded batch size.

        Graphs are captured for every entry of ``_BATCH_SIZES_TO_CAPTURE``
        up to ``_get_graph_batch_size(max_num_seqs)``. That is exactly the
        set of sizes ``_prepare_decode`` can pad a batch to.
        """
        # NOTE: This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graph
        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
                       "per GPU. If you are running out of memory, consider "
                       "decreasing `gpu_memory_utilization` or enforcing "
                       "eager mode.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size,
                                   1,
                                   dtype=torch.long,
                                   device=self.device)
        input_positions = torch.zeros(max_batch_size,
                                      1,
                                      dtype=torch.long,
                                      device=self.device)
        slot_mapping = torch.empty(max_batch_size,
                                   1,
                                   dtype=torch.long,
                                   device=self.device)
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size,
                                  dtype=torch.int32,
                                  device=self.device)
        block_tables = torch.from_numpy(self.graph_block_tables).to(
            self.device)

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        # NOTE: There are 3 backends for all-reduce: custom all-reduce
        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or CuPy NCCL if it is disabled or not supported.
        # Initialize a new progress bar
        progress = get_loading_progress_bar()
        task = progress.add_task("[cyan]Capturing graph...",
                                 total=len(batch_size_capture_list))

        with progress, custom_all_reduce.capture():
            # Every size in the list must be captured, including sizes above
            # max_num_seqs: _prepare_decode pads a batch of max_num_seqs up to
            # graph_batch_size, and execute_model looks that graph up by key.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy input_metadata.
                input_metadata = InputMetadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
                    max_seq_len=None,
                    start_loc=None,
                    max_context_len=self.max_context_len_to_capture,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                    kv_quant_params=self.kv_quant_params,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    input_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Later (smaller) graphs share the first graph's memory pool.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner
                # Update the progress bar
                progress.update(task, advance=1)

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

    def __del__(self) -> None:
        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
        # NOTE: This is necessary because otherwise deadlocks can
        # happen.
        # FIXME: This is a bit hacky. Find a more robust solution.
        # getattr: __init__ may raise before graph_runners is assigned, and
        # __del__ still runs on the partially built object.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.cupy_nccl_backend = None
class CUDAGraphRunner:
    """Records a model forward pass as a CUDA graph and replays it.

    ``capture`` runs the model once eagerly as a warm-up, then records a
    second run into a ``torch.cuda.CUDAGraph``. It keeps references to the
    input tensors and to the output tensor of that run. ``forward`` copies
    new inputs into those same tensors, replays the graph, and returns the
    captured output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        """Record one forward pass of ``self.model`` into a new CUDA graph.

        Args:
            input_ids, positions, kv_caches, input_metadata: Tensors the
                graph reads. Later replays reuse these same tensors.
            memory_pool: Pool handle passed to ``torch.cuda.graph`` so several
                graphs can share memory.
        """
        assert self.graph is None
        model_args = (input_ids, positions, kv_caches, input_metadata)

        # Warm-up run outside of capture, so one-off kernel launches (e.g.
        # Triton autotuning) are not recorded into the graph.
        with _maybe_cupy_nccl():
            self.model(*model_args)
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE: Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool), _maybe_cupy_nccl():
            captured_output = self.model(*model_args)
        torch.cuda.synchronize()

        # Remember the exact tensors the graph was recorded against.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": captured_output}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Copy fresh inputs into the captured tensors and replay the graph.

        The returned tensor is the captured output buffer itself, so the next
        replay overwrites it.
        """
        # The graph was captured against these very KV-cache tensors, so they
        # need no copy.
        del kv_caches

        staged = (
            ("input_ids", input_ids),
            ("positions", positions),
            ("slot_mapping", input_metadata.slot_mapping),
            ("context_lens", input_metadata.context_lens),
            ("block_tables", input_metadata.block_tables),
        )
        for buffer_name, fresh_value in staged:
            self.input_buffers[buffer_name].copy_(fresh_value,
                                                  non_blocking=True)

        self.graph.replay()
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
@contextlib.contextmanager
def _maybe_cupy_nccl():
    """Route all-reduce through CuPy NCCL for the duration of the block.

    This applies only when CuPy is initialized and the custom all-reduce
    kernel is not. In every other case the block runs unchanged.
    """
    use_cupy_nccl = (cupy_utils.is_initialized()
                     and not custom_all_reduce.is_initialized())
    if not use_cupy_nccl:
        yield
        return
    with with_cupy_nccl_for_all_reduce():
        yield
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    """Return a new list: ``x`` right-padded with ``pad`` up to ``max_len``.

    Asserts that ``x`` is not already longer than ``max_len``.
    """
    shortfall = max_len - len(x)
    assert shortfall >= 0
    return x + shortfall * [pad]
def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
    """Right-pad every row of ``x`` with ``pad`` to ``max_len``; stack them.

    The result has shape ``(len(x), max_len)`` with the given dtype and
    device. Asserts that no row is longer than ``max_len``.
    """
    rows: List[List[int]] = []
    for row in x:
        assert len(row) <= max_len
        rows.append(row + [pad] * (max_len - len(row)))
    return torch.tensor(rows, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int:
    """Round ``batch_size`` up to the padded size used for CUDA graphs.

    Sizes 1 and 2 are kept as-is, 3-4 become 4, and anything larger is
    rounded up to a multiple of 8. For batch sizes up to 256 the result is
    an entry of ``_BATCH_SIZES_TO_CAPTURE``.
    """
    if batch_size > 4:
        # Ceiling division by 8, then scale back up.
        return -(-batch_size // 8) * 8
    return batch_size if batch_size <= 2 else 4
def _async_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from ``data`` and copy it to ``target_device``.

    The host tensor is optionally allocated in pinned memory, and the copy
    is issued with ``non_blocking=True``.
    """
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
    return host_tensor.to(device=target_device, non_blocking=True)
|