12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219 |
- import contextlib
- import time
- from enum import IntEnum
- from typing import Dict, List, NamedTuple, Optional, Set, Tuple
- import numpy as np
- import torch
- import torch.nn as nn
- from loguru import logger
- from aphrodite.attention import (AttentionMetadata, AttentionMetadataPerStage,
- get_attn_backend)
- from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
- ModelConfig, ParallelConfig,
- SchedulerConfig, VisionLanguageConfig)
- from aphrodite.common.sampling_params import SamplingParams, SamplingType
- from aphrodite.common.sequence import (MultiModalData, SamplerOutput,
- SequenceData, SequenceGroupMetadata)
- from aphrodite.common.utils import (CudaMemoryProfiler, async_tensor_h2d,
- is_hip, is_pin_memory_available,
- make_tensor_with_pad, maybe_expand_dim)
- from aphrodite.distributed import (broadcast_tensor_dict,
- get_tensor_model_parallel_world_size,
- with_pynccl_for_all_reduce)
- from aphrodite.distributed.device_communicators import (custom_all_reduce,
- pynccl_utils)
- from aphrodite.lora.layers import LoRAMapping
- from aphrodite.lora.request import LoRARequest
- from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
- from aphrodite.modeling import SamplingMetadata
- from aphrodite.modeling.model_loader import get_model
- from aphrodite.modeling.sampling_metadata import PersistentMetadata
# Slot id used for tokens that must not be written to the KV cache (CUDA
# graph padding, sliding-window-masked prompt tokens, profiling runs).
_PAD_SLOT_ID = -1
# LoRA rank given to the dummy LoRAs registered during profile_run.
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Batch sizes for which CUDA graphs are captured: 1, 2, 4, followed by every
# multiple of _BATCH_SIZE_ALIGNMENT from 8 up to 256 (8, 16, 24, ..., 256).
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 32 + 1,
          _BATCH_SIZE_ALIGNMENT))
class PreparePromptMetadata(NamedTuple):
    """Host-side prefill inputs produced by ``ModelRunner._prepare_prompt``."""
    # Token ids and positions of every prompt chunk, flattened in batch order.
    input_tokens: List[int]
    input_positions: List[int]
    # Per-stage attention metadata for the prefill part; None when empty.
    attn_metadata: Optional[AttentionMetadataPerStage]
    # Per sequence: total length up to the end of the current chunk.
    prompt_lens: List[int]
    # Per sequence: number of tokens actually fed to the model this step.
    subquery_lens: List[int]
    # LoRA id per input token / per logits row kept for sampling.
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    # Concatenated multi-modal inputs, or None when the batch has none.
    multi_modal_input: Optional[torch.Tensor]
    # KV cache slot per input token (_PAD_SLOT_ID for tokens not stored).
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PreparePromptMetadata":
        """Return an instance that describes no prompts at all.

        New lists/sets are created on every call, so callers may extend the
        returned containers in place. ``cls`` is used so that subclasses get
        an instance of their own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )
class PrepareDecodeMetadata(NamedTuple):
    """Host-side decode inputs produced by ``ModelRunner._prepare_decode``."""
    # One token id / position per decoding sequence (plus CUDA graph padding).
    input_tokens: List[int]
    input_positions: List[int]
    # Per-stage attention metadata for the decode part; None when empty. It is
    # built by the same ``attn_backend.make_metadata`` call as the prefill
    # metadata and later passed as ``AttentionMetadata(decode_metadata=...)``.
    attn_metadata: Optional[AttentionMetadataPerStage]
    # LoRA id per input token / per sampled row.
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    # KV cache slot per input token (_PAD_SLOT_ID for padding entries).
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PrepareDecodeMetadata":
        """Return an instance that describes no decode sequences.

        New lists/sets are created on every call, so callers may extend the
        returned containers in place. ``cls`` is used so that subclasses get
        an instance of their own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )
class BatchType(IntEnum):
    """How the requests of one batch are composed.

    PREFILL -- every request in the batch is a prefill (prompt) request.
    DECODE  -- every request in the batch is a decode request.
    MIXED   -- the batch contains both prefill and decode requests.
    """
    PREFILL = 0
    DECODE = 1
    MIXED = 2
- class ModelRunner:
def __init__(
    self,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    kv_cache_dtype: Optional[str] = "auto",
    is_driver_worker: bool = False,
    vision_language_config: Optional[VisionLanguageConfig] = None,
):
    """Store the engine configuration; heavy state is created later.

    The model is created by ``load_model``, and ``block_size`` /
    ``graph_block_tables`` by ``set_block_size``. Until then those
    attributes are only declared.

    Args:
        model_config: Model configuration. May be ``None`` (see FIXME
            below). The sliding window, CUDA graph context length and
            attention backend dtype then fall back to None / 0 / None.
        parallel_config: Parallelism configuration, forwarded to
            ``get_model``.
        scheduler_config: Scheduler configuration (batch limits, chunked
            prefill flag).
        device_config: Device configuration. A default ``DeviceConfig()``
            is used when ``None``.
        load_config: Weight loading configuration, forwarded to
            ``get_model``.
        lora_config: LoRA configuration, or ``None`` to disable LoRA.
        kv_cache_dtype: KV cache data type name (e.g. ``"auto"``, ``"fp8"``).
        is_driver_worker: Whether this worker builds the inputs and
            broadcasts them to the other workers.
        vision_language_config: Vision-language configuration, or ``None``
            for text-only models.
    """
    self.model_config = model_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.lora_config = lora_config
    self.load_config = load_config
    self.is_driver_worker = is_driver_worker

    # model_config can be None in tests/samplers/test_sampler.py.
    # FIXME: This is a hack to make the tests work. Refactor this.
    self.sliding_window = (model_config.get_sliding_window()
                           if model_config is not None else None)
    self.device_config = (device_config
                          if device_config is not None else DeviceConfig())
    self.device = self.device_config.device

    # Set after load_model (only when LoRA is configured).
    self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    # Maps a (padded) decode batch size to its captured CUDA graph runner.
    self.graph_runners: Dict[int, CUDAGraphRunner] = {}
    self.graph_memory_pool: Optional[Tuple[
        int, int]] = None  # Set during graph capture.

    self.max_context_len_to_capture = (
        self.model_config.max_context_len_to_capture
        if self.model_config is not None else 0)

    self.pin_memory = is_pin_memory_available()
    self.kv_cache_dtype = kv_cache_dtype
    self.vision_language_config = vision_language_config

    self.attn_backend = get_attn_backend(
        self.model_config.dtype if model_config is not None else None)

    # Lazy initialization
    self.model: torch.nn.Module  # Set after load_model
    self.block_size: int  # Set after initial profiling.
    # When using CUDA graph, the input block tables must be padded to
    # max_context_len_to_capture. However, creating the block table in
    # Python can be expensive. To optimize this, we cache the block table
    # in numpy and only copy the actual input content at every iteration.
    # The shape of the cached block table will be
    # (max batch size to capture, max context len to capture / block size).
    self.graph_block_tables: np.ndarray  # Set after initial profiling.
def load_model(self) -> None:
    """Load the model weights and apply optional wrappers.

    The memory consumed while loading is stored in ``model_memory_usage``
    and logged. When LoRA is configured, the model is wrapped by a newly
    created ``LRUCacheWorkerLoRAManager``. For an FP8 KV cache on ROCm,
    KV cache scaling factors are loaded from
    ``model_config.quantization_param_path`` if one is given.

    Raises:
        AssertionError: if LoRA is configured but the model lacks
            ``supported_lora_modules``, ``embedding_modules`` or
            ``embedding_padding_modules``.
        RuntimeError: if FP8 scaling factors are provided but the model
            has no callable ``load_kv_cache_scales``.
    """
    with CudaMemoryProfiler() as m:
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
        )

    self.model_memory_usage = m.consumed_memory
    tp = get_tensor_model_parallel_world_size()
    logger.info(
        "Model weights loaded. Memory usage: "
        f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
        f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")

    if self.lora_config:
        assert hasattr(self.model, "supported_lora_modules"
                       ) and self.model.supported_lora_modules, (
                           "Model does not support LoRA")
        assert hasattr(
            self.model,
            "embedding_modules"), "Model does not have embedding_modules"
        assert hasattr(self.model, "embedding_padding_modules"
                       ), "Model does not have embedding_padding_modules"
        self.lora_manager = LRUCacheWorkerLoRAManager(
            self.scheduler_config.max_num_seqs,
            self.scheduler_config.max_num_batched_tokens, self.vocab_size,
            self.lora_config, self.device, self.model.embedding_modules,
            self.model.embedding_padding_modules)
        self.model = self.lora_manager.create_lora_manager(self.model)

    if self.kv_cache_dtype == "fp8" and is_hip():
        # Currently scaled KV cache is only enabled on ROCm
        if self.model_config.quantization_param_path is not None:
            if callable(getattr(self.model, "load_kv_cache_scales", None)):
                self.model.load_kv_cache_scales(
                    self.model_config.quantization_param_path)
            else:
                raise RuntimeError("Using FP8 KV cache and scaling "
                                   "factors provided but model "
                                   f"{self.model.__class__} does not "
                                   "support loading scaling factors.")
        else:
            # loguru's logger exposes ``warning``; there is no ``warn``
            # alias, so calling ``warn`` would raise AttributeError here.
            logger.warning("Using FP8 KV cache but no scaling factors "
                           "provided. Defaulting to scaling factors of 1.0. "
                           "This may lead to less accurate results!")
    elif self.model_config.quantization_param_path is not None:
        logger.warning("KV cache scaling factors provided, "
                       "but the KV cache data type is not FP8. "
                       "KV cache scaling factors will not be used.")
def set_block_size(self, block_size: int) -> None:
    """Record the KV cache block size and allocate the CUDA graph block table.

    The cached table is a zero-filled ``int32`` numpy array with one row for
    each slot of the largest capturable batch and one column per block
    needed to cover ``max_context_len_to_capture``.
    """
    self.block_size = block_size
    num_rows = max(_BATCH_SIZES_TO_CAPTURE)
    num_cols = self.get_max_block_per_batch()
    self.graph_block_tables = np.zeros((num_rows, num_cols), dtype=np.int32)
def get_max_block_per_batch(self) -> int:
    """Return how many KV cache blocks cover ``max_context_len_to_capture``.

    This is ``ceil(max_context_len_to_capture / block_size)`` computed with
    integer arithmetic.
    """
    size = self.block_size
    longest = self.max_context_len_to_capture
    return (longest + size - 1) // size
def _prepare_prompt(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
    """Build flattened model inputs and attention metadata for prefills.

    Every group must be a prompt group with exactly one sequence. For each
    sequence, only the tokens of the current chunk are emitted: from the
    number of already-computed tokens up to ``token_chunk_size`` tokens
    later, capped at the sequence length. With prefix caching
    (non-empty ``computed_block_nums`` and no sliding window), the tokens
    covered by the cached blocks are skipped as well.

    Returns:
        A ``PreparePromptMetadata``, or ``PreparePromptMetadata.empty()``
        when ``seq_group_metadata_list`` is empty.

    Raises:
        RuntimeError: if chunked prefill is enabled and a group also has
            cached prefix blocks (``computed_block_nums``).
    """
    input_tokens: List[int] = []
    input_positions: List[int] = []
    slot_mapping: List[int] = []
    lora_index_mapping: List[int] = []
    lora_prompt_mapping: List[int] = []
    lora_requests: Set[LoRARequest] = set()

    prompt_lens: List[int] = []
    context_lens: List[int] = []
    subquery_lens: List[int] = []
    prefix_block_tables: List[List[int]] = []
    multi_modal_input_list: List[torch.Tensor] = []

    if len(seq_group_metadata_list) == 0:
        return PreparePromptMetadata.empty()

    # scheduler_config may be None (see the FIXME in __init__), so it is
    # guarded once here and the flag is reused by both checks below.
    chunked_prefill_enabled = (
        self.scheduler_config is not None
        and self.scheduler_config.chunked_prefill_enabled)

    for seq_group_metadata in seq_group_metadata_list:
        assert seq_group_metadata.is_prompt
        seq_ids = list(seq_group_metadata.seq_data.keys())
        assert len(seq_ids) == 1
        seq_id = seq_ids[0]

        computed_block_nums = seq_group_metadata.computed_block_nums
        if (chunked_prefill_enabled
                and not (computed_block_nums is None
                         or computed_block_nums == [])):
            raise RuntimeError(
                "chunked prefill cannot be used with prefix caching "
                "now.")

        token_chunk_size = seq_group_metadata.token_chunk_size
        seq_data = seq_group_metadata.seq_data[seq_id]
        computed_len = seq_data.get_num_computed_tokens()
        # We should use get_len here because in case of preemption
        # it contains output tokens.
        prefill_end = min(seq_data.get_len(),
                          computed_len + token_chunk_size)
        prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
        prompt_len = prefill_end
        prompt_lens.append(prompt_len)

        # NOTE: This only works for oooooooxxx style attention, i.e. a
        # cached prefix (o) followed by the tokens computed now (x).
        if computed_block_nums is not None and len(
                computed_block_nums) > 0 and self.sliding_window is None:
            # Prefix is not supported with sliding_window
            computed_len = len(computed_block_nums) * self.block_size
            prompt_tokens = prompt_tokens[computed_len:]
            prefix_block_tables.append(computed_block_nums)
        elif chunked_prefill_enabled:
            if seq_group_metadata.block_tables is not None:
                # Prefill has chunked before.
                block_table = seq_group_metadata.block_tables[seq_id]
                prefix_block_tables.append(block_table)
            else:
                # The first prefill.
                prefix_block_tables.append([])
        else:
            prefix_block_tables.append([])
            # Without prefix caching or chunked prefill, the prefill must
            # start at the first token.
            assert computed_len == 0

        # actual prompt lens
        context_lens.append(computed_len)
        subquery_lens.append(prompt_len - computed_len)

        input_tokens.extend(prompt_tokens)
        # NOTE: Here we assume that the first token in the prompt
        # is always the first token in the sequence.
        input_positions.extend(range(computed_len, prefill_end))

        lora_id = seq_group_metadata.lora_int_id
        if lora_id > 0:
            lora_requests.add(seq_group_metadata.lora_request)

        lora_index_mapping += [lora_id] * (prompt_len - computed_len)
        # One entry per logits row that _prepare_sample keeps for this
        # group. That method keeps every prompt row whenever
        # ``prompt_logprobs is not None`` (including 0), so the same test
        # must be used here or the two mappings drift out of alignment.
        wants_prompt_logprobs = (
            seq_group_metadata.sampling_params.prompt_logprobs is not None)
        lora_prompt_mapping.extend(
            [lora_id] *
            (prompt_len - computed_len if wants_prompt_logprobs else 1))

        if seq_group_metadata.multi_modal_data:
            multi_modal_input_list.append(
                seq_group_metadata.multi_modal_data.data)

        if seq_group_metadata.block_tables is None:
            # During memory profiling, the block tables are not initialized
            # yet. In this case, we just use a dummy slot mapping.
            slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
            continue

        # Compute the slot mapping.
        block_table = seq_group_metadata.block_tables[seq_id]
        # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
        # where start_idx is max(0, prompt_len - sliding_window).
        # For example, if the prompt len is 10, sliding window is 8, block
        # size is 4 and the block table is [0, 1, 0], the first two tokens
        # are masked and the slot mapping will be
        # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
        start_idx = 0
        if self.sliding_window is not None:
            assert computed_len == 0, (
                "Prefix caching is currently not supported with "
                "sliding window attention")
            start_idx = max(0, prompt_len - self.sliding_window)

        for i in range(computed_len, prefill_end):
            if i < start_idx:
                slot_mapping.append(_PAD_SLOT_ID)
                continue
            block_number = block_table[i // self.block_size]
            block_offset = i % self.block_size
            slot = block_number * self.block_size + block_offset
            slot_mapping.append(slot)

    max_subquery_len = max(subquery_lens)
    max_prompt_len = max(prompt_lens)
    assert max_subquery_len > 0

    context_lens_tensor = torch.tensor(context_lens,
                                       dtype=torch.int,
                                       device=self.device)

    if multi_modal_input_list:
        assert self.vision_language_config, (
            "Multi-modal inputs are only supported by "
            "vision language models.")
        multi_modal_input = torch.cat(multi_modal_input_list,
                                      dim=0).to(self.device)
    else:
        multi_modal_input = None

    # Prepare prefix block tables, right-padded with 0 to a common length.
    max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
    block_tables = make_tensor_with_pad(
        prefix_block_tables,
        max_len=max_prompt_block_table_len,
        pad=0,
        dtype=torch.int,
        device=self.device,
    )

    # Query length can be shorter than key (i.e., prompt) when prefill
    # is chunked or prefix cached.
    subquery_lens_tensor = torch.tensor(subquery_lens,
                                        dtype=torch.long,
                                        device=self.device)
    subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=self.device)

    prompt_lens_tensor = torch.tensor(prompt_lens,
                                      dtype=torch.long,
                                      device=self.device)
    seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                dtype=torch.int32,
                                device=self.device)

    # Exclusive prefix sums: element k is the start offset of sequence k.
    torch.cumsum(subquery_lens_tensor,
                 dim=0,
                 dtype=subquery_start_loc.dtype,
                 out=subquery_start_loc[1:])

    torch.cumsum(prompt_lens_tensor,
                 dim=0,
                 dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])

    attn_metadata = self.attn_backend.make_metadata(
        is_prompt=True,
        prompt_lens=prompt_lens,
        prompt_lens_tensor=prompt_lens_tensor,
        max_subquery_len=max_subquery_len,
        max_context_len=None,
        max_prompt_len=max_prompt_len,
        subquery_start_loc=subquery_start_loc,
        seq_start_loc=seq_start_loc,
        context_lens=context_lens_tensor,
        block_tables=block_tables,
        use_cuda_graph=False,
    )

    return PreparePromptMetadata(
        input_tokens=input_tokens,
        input_positions=input_positions,
        attn_metadata=attn_metadata,
        prompt_lens=prompt_lens,
        subquery_lens=subquery_lens,
        lora_index_mapping=lora_index_mapping,
        lora_prompt_mapping=lora_prompt_mapping,
        lora_requests=lora_requests,
        multi_modal_input=multi_modal_input,
        slot_mapping=slot_mapping,
    )
def _prepare_decode(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PrepareDecodeMetadata:
    """Build one-token-per-sequence inputs and attention metadata for decodes.

    Every group must be a decode group with ``token_chunk_size == 1``. Each
    of its sequences contributes its last token at position ``len - 1``.
    With a sliding window, the context length is capped at the window and
    only the last ``sliding_window // block_size`` block table entries are
    kept.

    When CUDA graphs can be used (not ``enforce_eager``, batch no larger
    than the largest captured size, context no longer than
    ``max_context_len_to_capture``), the batch is padded up to
    ``_get_graph_batch_size`` with dummy entries. The block tables are then
    copied into the preallocated ``graph_block_tables`` array.

    Returns:
        A ``PrepareDecodeMetadata``, or ``PrepareDecodeMetadata.empty()``
        when ``seq_group_metadata_list`` is empty.
    """
    if len(seq_group_metadata_list) == 0:
        return PrepareDecodeMetadata.empty()

    tokens: List[int] = []
    positions: List[int] = []
    slots: List[int] = []
    ctx_lens: List[int] = []
    tables: List[List[int]] = []
    lora_index_mapping: List[int] = []
    lora_prompt_mapping: List[int] = []
    lora_requests: Set[LoRARequest] = set()

    block_size = self.block_size
    window = self.sliding_window

    for group in seq_group_metadata_list:
        assert not group.is_prompt
        assert group.token_chunk_size == 1

        group_seq_ids = list(group.seq_data.keys())
        lora_id = group.lora_int_id
        if lora_id > 0:
            lora_requests.add(group.lora_request)

        for seq_id in group_seq_ids:
            seq_data = group.seq_data[seq_id]
            tokens.append(seq_data.get_last_token_id())

            seq_len = seq_data.get_len()
            position = seq_len - 1
            positions.append(position)
            ctx_lens.append(seq_len if window is None else min(
                seq_len, window))

            table = group.block_tables[seq_id]
            block_idx, offset = divmod(position, block_size)
            slots.append(table[block_idx] * block_size + offset)

            lora_index_mapping.append(lora_id)
            lora_prompt_mapping.append(lora_id)

            if window is not None:
                table = table[-(window // block_size):]
            tables.append(table)

    # Aphrodite uses cuda graph only for decoding requests; here the batch
    # size equals the number of input tokens. See `capture_model`.
    real_batch_size = len(tokens)
    max_context_len = max(ctx_lens)
    use_captured_graph = (
        not self.model_config.enforce_eager
        and real_batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
        and max_context_len <= self.max_context_len_to_capture)

    batch_size = real_batch_size
    if use_captured_graph:
        batch_size = _get_graph_batch_size(real_batch_size)
        assert batch_size >= real_batch_size
        num_pad = batch_size - real_batch_size
        # Dummy entries: token 0 at position 0, nothing written to the KV
        # cache, context length 1 and an empty block table. The LoRA
        # prompt mapping is intentionally left unpadded.
        tokens.extend([0] * num_pad)
        positions.extend([0] * num_pad)
        slots.extend([_PAD_SLOT_ID] * num_pad)
        ctx_lens.extend([1] * num_pad)
        tables.extend([] for _ in range(num_pad))
        lora_index_mapping.extend([0] * num_pad)

    context_lens_tensor = torch.tensor(ctx_lens,
                                       dtype=torch.int,
                                       device=self.device)

    if use_captured_graph:
        # Every per-token list must have been padded to the graph size.
        padded_len = context_lens_tensor.shape[0]
        assert padded_len == len(tokens)
        assert padded_len == len(positions)
        assert padded_len == len(slots)

        # graph_block_tables is [max batch size, max context len // block
        # size]. NOTE(review): rows are only overwritten up to each table's
        # length, so entries beyond it keep values from earlier calls.
        graph_tables = self.graph_block_tables[:batch_size]
        for row, table in enumerate(tables):
            if table:
                graph_tables[row, :len(table)] = table
        block_tables = torch.tensor(graph_tables, device=self.device)
    else:
        block_tables = make_tensor_with_pad(
            tables,
            max_len=max(len(table) for table in tables),
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

    attn_metadata = self.attn_backend.make_metadata(
        is_prompt=False,
        prompt_lens=None,
        prompt_lens_tensor=None,
        max_subquery_len=None,
        max_context_len=max_context_len,
        max_prompt_len=None,
        subquery_start_loc=None,
        seq_start_loc=None,
        context_lens=context_lens_tensor,
        block_tables=block_tables,
        use_cuda_graph=use_captured_graph,
    )

    return PrepareDecodeMetadata(
        input_tokens=tokens,
        input_positions=positions,
        attn_metadata=attn_metadata,
        lora_index_mapping=lora_index_mapping,
        lora_prompt_mapping=lora_prompt_mapping,
        lora_requests=lora_requests,
        slot_mapping=slots,
    )
def _prepare_sample(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    prompt_lens: List[int],
    subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
    """Build the ``SamplingMetadata`` for a batch on the driver worker.

    Walks the groups in list order and computes two things:

    * ``selected_token_indices``: which rows of the model output are kept
      for logits. For a prompt group this is its last row, or all of its
      rows when ``prompt_logprobs is not None``. For a decode group it is
      one row per sequence.
    * ``categorized_sample_indices``: for each ``SamplingType``, a list of
      pairs. The first element advances in step with
      ``selected_token_indices``, so it is presumably a row index into the
      kept logits. The second advances by one per sampled token.

    For prompt groups whose ``sampling_params.seed`` is set, a new seeded
    ``torch.Generator`` is stored on ``seq_group_metadata.state``. The
    generator of every seeded group is collected into ``generators``.

    NOTE(review): ``subquery_lens`` is built from the prompt groups only,
    but is indexed here by ``i``, the group's position in the full list.
    Row offsets also follow list order while the model input lays out all
    prefill tokens first. Both are correct only if every prompt group
    precedes every decode group in ``seq_group_metadata_list``; confirm
    the scheduler guarantees this.

    Args:
        seq_group_metadata_list: All groups of the batch (prefill and
            decode).
        prompt_lens: Passed through to ``SamplingMetadata``.
        subquery_lens: Number of tokens fed this step per prompt group;
            required when any group is a prompt.
    """
    seq_groups: List[Tuple[List[int], SamplingParams]] = []
    selected_token_indices: List[int] = []
    generators: List[torch.Generator] = []
    # Running offset of the current group's first row in the model output.
    selected_token_start_idx = 0
    categorized_sample_indices: Dict[SamplingType,
                                     List[Tuple[int, int]]] = {
                                         t: []
                                         for t in SamplingType
                                     }
    categorized_sample_indices_start_idx = 0
    categorized_sampled_token_indices_start_idx = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = list(seq_group_metadata.seq_data.keys())
        sampling_params = seq_group_metadata.sampling_params
        seq_groups.append((seq_ids, sampling_params))

        if seq_group_metadata.is_prompt:
            assert len(seq_ids) == 1
            assert subquery_lens is not None
            subquery_len = subquery_lens[i]
            if sampling_params.prompt_logprobs is not None:
                # NOTE: prompt token positions do not need sample, skip
                categorized_sample_indices_start_idx += subquery_len - 1

            categorized_sample_indices[
                sampling_params.sampling_type].append(
                    (categorized_sample_indices_start_idx,
                     categorized_sampled_token_indices_start_idx))
            categorized_sample_indices_start_idx += 1
            categorized_sampled_token_indices_start_idx += 1

            if sampling_params.prompt_logprobs is not None:
                # Keep every prompt row except the last one...
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + subquery_len - 1))
            # ...and always keep the last row, which is sampled from.
            selected_token_indices.append(selected_token_start_idx +
                                          subquery_len - 1)
            selected_token_start_idx += subquery_len

            if sampling_params.seed is not None:
                seq_group_metadata.state.generator = torch.Generator(
                    device=self.device).manual_seed(sampling_params.seed)
        else:
            # Decode: exactly one row per sequence, each one sampled.
            num_seqs = len(seq_ids)
            selected_token_indices.extend(
                range(selected_token_start_idx,
                      selected_token_start_idx + num_seqs))
            selected_token_start_idx += num_seqs

            categorized_sample_indices[
                sampling_params.sampling_type].extend(
                    list(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx
                                + num_seqs))))
            categorized_sample_indices_start_idx += num_seqs
            categorized_sampled_token_indices_start_idx += num_seqs

        # Decode groups reuse the generator created at their prompt step.
        if sampling_params.seed is not None:
            generators.append(seq_group_metadata.state.generator)

    selected_token_indices = async_tensor_h2d(selected_token_indices,
                                              dtype=torch.long,
                                              target_device=self.device,
                                              pin_memory=self.pin_memory)

    # maybe_expand_dim(..., 2, 2) presumably keeps each tensor 2-D (pairs)
    # even when its list is empty -- TODO confirm against the helper.
    categorized_sample_indices = {
        t: maybe_expand_dim(
            async_tensor_h2d(seq_ids,
                             dtype=torch.int,
                             target_device=self.device,
                             pin_memory=self.pin_memory), 2, 2)
        for t, seq_ids in categorized_sample_indices.items()
    }

    seq_data: Dict[int, SequenceData] = {}
    for seq_group_metadata in seq_group_metadata_list:
        seq_data.update(seq_group_metadata.seq_data)

    seq_persistence_data: Dict[int, dict] = {}
    for grp in seq_group_metadata_list:
        seq_persistence_data.update(grp.persistent_data)

    sampling_metadata = SamplingMetadata(
        seq_groups=seq_groups,
        seq_data=seq_data,
        prompt_lens=prompt_lens,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        generators=generators,
        persistent_metadata=PersistentMetadata(seq_persistence_data),
    )
    return sampling_metadata
def prepare_input_tensors(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
           Set[LoRARequest], LoRAMapping, torch.Tensor]:
    """Build the model inputs for one step and share them across workers.

    Driver worker (``is_driver_worker``):
        1. Splits the groups into prefill and decode groups and prepares
           each part.
        2. Concatenates prefill entries followed by decode entries into
           single token/position/slot tensors.
        3. Builds the full ``SamplingMetadata``.
        4. Broadcasts everything from rank 0. This is one broadcast for
           PREFILL/DECODE batches and two for MIXED batches.

    Other workers receive those broadcasts and rebuild the attention
    metadata. Their ``SamplingMetadata`` carries only
    ``selected_token_indices`` and has ``perform_sampling=False``.

    The number and order of ``broadcast_tensor_dict`` calls must match
    between the two branches below.

    Returns:
        ``(input_tokens, input_positions, attn_metadata, sampling_metadata,
        lora_requests, lora_mapping, multi_modal_input)``. ``lora_mapping``
        is ``None`` when LoRA is not configured.
    """
    if self.is_driver_worker:
        prefill_reqs = []
        decode_reqs = []
        for seq_group_meta in seq_group_metadata_list:
            if seq_group_meta.is_prompt:
                prefill_reqs.append(seq_group_meta)
            else:
                decode_reqs.append(seq_group_meta)

        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            prefill_attn_metadata,
            prompt_lens,
            subquery_lens,
            lora_index_mapping,
            lora_prompt_mapping,
            lora_requests,
            multi_modal_input,
            slot_mapping,
        ) = self._prepare_prompt(prefill_reqs)
        (
            decode_input_tokens,
            decode_input_positions,
            decode_attn_metadata,
            decode_lora_index_mapping,
            decode_lora_prompt_mapping,
            decode_lora_requests,
            decode_slot_mapping,
        ) = self._prepare_decode(decode_reqs)
        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 prompt_lens,
                                                 subquery_lens)

        if not self.scheduler_config.chunked_prefill_enabled:
            # `(a and b) == 0` holds iff at least one of the two lengths is
            # 0: without chunked prefill a batch must not be mixed.
            assert (len(prefill_reqs) and len(decode_reqs)) == 0

        num_prefills = len(prompt_lens)
        # Measured before the decode entries are appended below.
        num_prefill_tokens = len(input_tokens)
        num_decode_tokens = len(decode_input_tokens)

        # Coalesce tensors. Note that attn_metadata is currently not
        # coalesced for simplicity.
        input_tokens.extend(decode_input_tokens)
        input_positions.extend(decode_input_positions)
        slot_mapping.extend(decode_slot_mapping)
        lora_index_mapping.extend(decode_lora_index_mapping)
        lora_prompt_mapping.extend(decode_lora_prompt_mapping)
        lora_requests.update(decode_lora_requests)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)

        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )
        else:
            lora_mapping = None

        # Broadcast the metadata.
        # If batch contains both prefill and decode, it sends 2 broadcasts.
        # If it only contains 1 type, it triggers a single broadcast.
        # NOTE(review): an empty seq_group_metadata_list yields DECODE with
        # decode_attn_metadata None, which trips the assert below.
        if (prefill_attn_metadata is not None
                and decode_attn_metadata is not None):
            batch_type = BatchType.MIXED
        elif prefill_attn_metadata is not None:
            batch_type = BatchType.PREFILL
        else:
            batch_type = BatchType.DECODE

        metadata_dict = {
            "input_tokens": input_tokens,
            "input_positions": input_positions,
            "selected_token_indices":
            sampling_metadata.selected_token_indices,
            "lora_requests": lora_requests,
            "lora_mapping": lora_mapping,
            "multi_modal_input": multi_modal_input,
            "num_prefill_tokens": num_prefill_tokens,
            "num_decode_tokens": num_decode_tokens,
            "slot_mapping": slot_mapping,
            "num_prefills": num_prefills,
            "batch_type": batch_type,
        }
        # The first broadcast carries the prefill attention metadata when
        # there is any, otherwise the decode attention metadata.
        if prefill_attn_metadata is not None:
            metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
        else:
            assert decode_attn_metadata is not None
            metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
        broadcast_tensor_dict(metadata_dict, src=0)

        # Broadcast decode attn metadata for mixed batch type.
        # The additional broadcast costs 300us overhead on 4 A10 GPUs.
        # We can potentially reduce the overhead by coalescing tensors.
        if batch_type == BatchType.MIXED:
            assert decode_attn_metadata is not None
            metadata_dict = decode_attn_metadata.asdict_zerocopy()
            broadcast_tensor_dict(metadata_dict, src=0)
    else:
        metadata_dict = broadcast_tensor_dict(src=0)
        input_tokens = metadata_dict.pop("input_tokens")
        input_positions = metadata_dict.pop("input_positions")
        slot_mapping = metadata_dict.pop("slot_mapping")
        num_prefills = metadata_dict.pop("num_prefills")
        selected_token_indices = metadata_dict.pop(
            "selected_token_indices")
        lora_mapping = metadata_dict.pop("lora_mapping")
        lora_requests = metadata_dict.pop("lora_requests")
        multi_modal_input = metadata_dict.pop("multi_modal_input")
        num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
        num_decode_tokens = metadata_dict.pop("num_decode_tokens")
        batch_type = metadata_dict.pop("batch_type")

        # Create an attention metadata.
        # The keys left in metadata_dict are the attention metadata fields.
        prefill_attn_metadata = None
        decode_attn_metadata = None
        if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
            prefill_attn_metadata = self.attn_backend.make_metadata(
                **metadata_dict)
        else:
            decode_attn_metadata = self.attn_backend.make_metadata(
                **metadata_dict)

        # Non-driver workers only need the selected rows to compute logits;
        # they never sample.
        sampling_metadata = SamplingMetadata(
            seq_groups=None,
            seq_data=None,
            prompt_lens=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            generators=None,
            perform_sampling=False,
        )

        # if it is a mixed batch, decode attn_metadata is broadcasted
        # separately.
        if batch_type == BatchType.MIXED:
            metadata_dict = broadcast_tensor_dict(src=0)
            decode_attn_metadata = self.attn_backend.make_metadata(
                **metadata_dict)

    attn_metadata = AttentionMetadata(
        num_prefills=num_prefills,
        slot_mapping=slot_mapping,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        prefill_metadata=prefill_attn_metadata,
        decode_metadata=decode_attn_metadata,
        kv_cache_dtype=self.kv_cache_dtype,
    )
    return (input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
    """Run one forward pass and, where sampling is enabled, sample tokens.

    A captured CUDA graph runner (looked up by the padded batch size) is
    used only for pure-decode batches whose decode metadata has
    ``use_cuda_graph`` set. Otherwise the eager model runs.

    Returns:
        The sampler output, or ``None`` when
        ``sampling_metadata.perform_sampling`` is false (non-driver workers).
    """
    prepared = self.prepare_input_tensors(seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata, sampling_metadata,
     lora_requests, lora_mapping, multi_modal_input) = prepared

    if self.lora_config:
        self.set_active_loras(lora_requests, lora_mapping)

    # Currently cuda graph is only supported by the decode phase.
    prefill_meta = attn_metadata.prefill_metadata
    decode_meta = attn_metadata.decode_metadata
    use_graph = prefill_meta is None and decode_meta.use_cuda_graph
    if use_graph:
        model_executable = self.graph_runners[input_tokens.shape[0]]
    else:
        model_executable = self.model

    model_kwargs = dict(
        input_ids=input_tokens,
        positions=input_positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
    )
    if self.vision_language_config:
        model_kwargs["image_input"] = multi_modal_input
    hidden_states = model_executable(**model_kwargs)

    # Logits are computed on every worker; only the driver samples.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)
    if not sampling_metadata.perform_sampling:
        return None

    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )
@torch.inference_mode()
def profile_run(self) -> None:
    """Run one worst-case forward pass on dummy data to profile memory.

    Builds up to ``max_num_seqs`` prompt-only sequence groups whose lengths
    sum to ``max_num_batched_tokens``, attaches dummy LoRAs round-robin when
    LoRA is enabled, and runs ``execute_model`` without KV caches.
    """
    # Enable top-k sampling to reflect the accurate memory usage.
    sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
    max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
    max_num_seqs = self.scheduler_config.max_num_seqs

    # Register ``max_loras`` dummy LoRAs (rank LORA_WARMUP_RANK) and assign
    # them round-robin to the dummy sequences. This represents the maximum
    # number of different requests with unique LoRAs, and therefore the
    # maximum LoRA-related memory consumption.
    dummy_lora_requests = []
    dummy_lora_requests_per_seq = []
    if self.lora_config:
        for idx in range(self.lora_config.max_loras):
            lora_id = idx + 1
            dummy_lora_request = LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_local_path="/not/a/real/path",
            )
            self.lora_manager.add_dummy_lora(dummy_lora_request,
                                             rank=LORA_WARMUP_RANK)
            dummy_lora_requests.append(dummy_lora_request)
        dummy_lora_requests_per_seq = [
            dummy_lora_requests[idx % len(dummy_lora_requests)]
            for idx in range(max_num_seqs)
        ]

    # Profile memory usage with max_num_sequences sequences and the total
    # number of tokens equal to max_num_batched_tokens.
    seqs: List[SequenceGroupMetadata] = []
    # Additional GPU memory may be needed for vision encoding, which needs
    # to be accounted for when calculating the GPU blocks for the
    # Aphrodite block manager.
    # To exercise the worst scenario for GPU memory consumption,
    # the number of seqs (batch_size) is chosen to maximize the number
    # of images processed.
    if self.vision_language_config:
        # Exact integer floor division (the original float division plus
        # int() truncation can lose precision for large values). If
        # max_num_batched_tokens < image_feature_size this yields 0 and no
        # dummy sequences are built.
        max_num_seqs = min(
            max_num_seqs,
            max_num_batched_tokens //
            self.vision_language_config.image_feature_size)
    for group_id in range(max_num_seqs):
        # Spread the token budget evenly; the first
        # (max_num_batched_tokens % max_num_seqs) groups get one extra token.
        seq_len = (max_num_batched_tokens // max_num_seqs +
                   (group_id < max_num_batched_tokens % max_num_seqs))
        seq_data, fake_multi_modal_input = _prepare_fake_inputs(
            seq_len, self.vision_language_config)
        seq = SequenceGroupMetadata(
            request_id=str(group_id),
            is_prompt=True,
            seq_data={group_id: seq_data},
            sampling_params=sampling_params,
            block_tables=None,
            persistent_data={},
            lora_request=dummy_lora_requests_per_seq[group_id]
            if dummy_lora_requests_per_seq else None,
            multi_modal_data=fake_multi_modal_input,
        )
        seqs.append(seq)

    # Run the model with the dummy inputs.
    num_layers = self.model_config.get_num_layers(self.parallel_config)
    kv_caches = [None] * num_layers
    self.execute_model(seqs, kv_caches)
    torch.cuda.synchronize()
def remove_all_loras(self) -> bool:
    """Remove every LoRA adapter held by the LoRA manager.

    Returns:
        The value returned by ``lora_manager.remove_all_loras()``.

    Raises:
        RuntimeError: if LoRA is not enabled (no LoRA manager).
    """
    manager = self.lora_manager
    if manager:
        return manager.remove_all_loras()
    raise RuntimeError("LoRA is not enabled.")
def set_active_loras(self, lora_requests: Set[LoRARequest],
                     lora_mapping: LoRAMapping) -> None:
    """Activate ``lora_requests`` with ``lora_mapping`` on the LoRA manager.

    Raises:
        RuntimeError: if LoRA is not enabled (no LoRA manager).
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    manager.set_active_loras(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Register ``lora_request`` with the LoRA manager.

    Returns:
        The value returned by ``lora_manager.add_lora(lora_request)``.

    Raises:
        RuntimeError: if LoRA is not enabled (no LoRA manager).
    """
    manager = self.lora_manager
    if manager:
        return manager.add_lora(lora_request)
    raise RuntimeError("LoRA is not enabled.")
def remove_lora(self, lora_id: int) -> bool:
    """Remove the LoRA identified by ``lora_id`` from the LoRA manager.

    Returns:
        The value returned by ``lora_manager.remove_lora(lora_id)``.

    Raises:
        RuntimeError: if LoRA is not enabled (no LoRA manager).
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    return manager.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
    """Return the LoRA ids known to the LoRA manager.

    Raises:
        RuntimeError: if LoRA is not enabled (no LoRA manager).
    """
    manager = self.lora_manager
    if manager:
        return manager.list_loras()
    raise RuntimeError("LoRA is not enabled.")
@torch.inference_mode()
def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
    """Cuda graph capture a model.

    Note that CUDA graph's performance gain is negligible if number
    of batched tokens are larger than 200. And since CUDA graph
    requires fixed sized tensors, supporting large/variable batch
    size requires high GPU memory overhead. Thus, Aphrodite only captures
    decoding requests. Mixed batch (chunked prefill + decoding) or
    prefill requests are not captured.

    Since it is used for decoding-only, it assumes there's only 1 token
    per sequence in the batch.

    One graph is captured for every batch size in
    ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed the padded
    ``max_num_seqs``. The runners are stored in ``self.graph_runners`` keyed
    by batch size, and all captures share ``self.graph_memory_pool``.

    Args:
        kv_caches: KV cache tensors passed to every captured forward pass.
    """
    # NOTE: This is a hack to ensure that the NCCL backend is never
    # deleted before the CUDA graphs.
    self.pynccl_backend = pynccl_utils.get_nccl_backend()

    assert not self.model_config.enforce_eager
    logger.info("Capturing the model for CUDA graphs. This may lead to "
                "unexpected consequences if the model is not static. To "
                "run the model in eager mode, set 'enforce_eager=True' or "
                "use '--enforce-eager' in the CLI.")
    logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                "If you are running out of memory, consider decreasing "
                "`gpu_memory_utilization` or enforcing eager mode. "
                "You can also reduce the `max_num_seqs` as needed "
                "to decrease memory usage.")
    start_time = time.perf_counter()

    # Prepare dummy inputs. These will be reused for all batch sizes.
    # They are allocated directly on the current CUDA device rather than
    # being created on the CPU and then copied over with ``.cuda()``.
    max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
    input_tokens = torch.zeros(max_batch_size,
                               dtype=torch.long,
                               device="cuda")
    input_positions = torch.zeros(max_batch_size,
                                  dtype=torch.long,
                                  device="cuda")
    slot_mapping = torch.full((max_batch_size, ),
                              _PAD_SLOT_ID,
                              dtype=torch.long,
                              device="cuda")
    context_lens = torch.ones(max_batch_size,
                              dtype=torch.int32,
                              device="cuda")
    block_tables = torch.from_numpy(self.graph_block_tables).cuda()

    graph_batch_size = _get_graph_batch_size(
        self.scheduler_config.max_num_seqs)
    batch_size_capture_list = [
        bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
    ]

    # NOTE: There are 3 backends for all-reduce: custom all-reduce
    # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
    # either custom all-reduce kernel or pynccl. When not using CUDA
    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
    # We always prioritize using custom all-reduce kernel but fall back
    # to PyTorch or pynccl if it is disabled or not supported.
    with custom_all_reduce.capture():
        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(batch_size_capture_list):
            # Create dummy attn_metadata (decode-only, one token per seq).
            decode_metadata = self.attn_backend.make_metadata(
                is_prompt=False,
                prompt_lens=None,
                prompt_lens_tensor=None,
                max_subquery_len=None,
                max_context_len=self.max_context_len_to_capture,
                max_prompt_len=None,
                subquery_start_loc=None,
                seq_start_loc=None,
                context_lens=context_lens[:batch_size],
                block_tables=block_tables[:batch_size],
                use_cuda_graph=True,
            )
            attn_metadata = AttentionMetadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size,
                slot_mapping=slot_mapping[:batch_size],
                prefill_metadata=None,
                decode_metadata=decode_metadata,
                kv_cache_dtype=self.kv_cache_dtype,
            )

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    [0] * batch_size,
                    [0] * batch_size,
                )
                self.set_active_loras(set(), lora_mapping)

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens[:batch_size],
                input_positions[:batch_size],
                kv_caches,
                attn_metadata,
                memory_pool=self.graph_memory_pool,
            )
            # Reuse this graph's pool for the next (smaller) capture.
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    # This usually takes < 10 seconds.
    logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
def __del__(self) -> None:
    """Drop the CUDA graphs first, then the pynccl backend reference.

    Deleting the CUDA graphs before the pynccl communicator is necessary
    because otherwise deadlocks can happen.
    """
    # FIXME: This is a bit hacky. Find a more robust solution.
    # TODO: when we get enough user feedback that pynccl is
    # more stable than cupy, we can remove this
    # __del__ may run on a partially-initialized object (e.g. __init__
    # raised before ``graph_runners`` was assigned). Reading the attribute
    # directly would then raise AttributeError, which Python reports as
    # "Exception ignored in ...", so tolerate its absence.
    graph_runners = getattr(self, "graph_runners", None)
    if graph_runners is not None:
        graph_runners.clear()
    self.pynccl_backend = None
@property
def vocab_size(self) -> int:
    """Vocabulary size as reported by the model config."""
    model_config = self.model_config
    return model_config.get_vocab_size()
class CUDAGraphRunner:
    """Replays a captured CUDA graph of one ``model`` forward pass.

    ``capture`` records the graph against a fixed set of input tensors and
    keeps references to them as static buffers. ``forward`` copies new
    inputs into those buffers in place and replays the graph, returning
    the captured output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Static tensors the graph reads from / writes to; set by capture().
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph; asserts that ``capture`` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Record one forward pass of the model into a new CUDA graph.

        Args:
            input_ids: Token ids; becomes a static input buffer.
            positions: Token positions; becomes a static input buffer.
            kv_caches: KV cache tensors forwarded to the model.
            attn_metadata: Attention metadata; its ``slot_mapping`` and the
                decode metadata's ``context_lens`` / ``block_tables`` become
                static input buffers.
            memory_pool: Forwarded as ``pool=`` to ``torch.cuda.graph``.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self._graph is None
        model_args = (input_ids, positions, kv_caches, attn_metadata)

        # Warm-up run without capturing, so that the captured graph does
        # not include the kernel launches for initial benchmarking
        # (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Record the graph.
        # NOTE: Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Remember the tensors the graph was recorded against; forward()
        # refreshes them in place before every replay.
        decode_meta = attn_metadata.decode_metadata
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "context_lens": decode_meta.context_lens,
            "block_tables": decode_meta.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the static buffers and replay the graph.

        ``kv_caches`` and any extra ``kwargs`` are not used here.

        Returns:
            The ``hidden_states`` tensor recorded during ``capture`` (the
            same tensor object on every call).
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        buffers = self.input_buffers
        buffers["input_ids"].copy_(input_ids, non_blocking=True)
        buffers["positions"].copy_(positions, non_blocking=True)
        buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                      non_blocking=True)
        decode_meta = attn_metadata.decode_metadata
        buffers["context_lens"].copy_(decode_meta.context_lens,
                                      non_blocking=True)
        buffers["block_tables"].copy_(decode_meta.block_tables,
                                      non_blocking=True)

        self.graph.replay()
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
@contextlib.contextmanager
def _maybe_pynccl():
    """Use pynccl for all-reduce only when it is the active backend.

    Enters ``with_pynccl_for_all_reduce()`` when pynccl is initialized and
    the custom all-reduce kernel is not; otherwise this is a no-op context.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    ctx = (with_pynccl_for_all_reduce()
           if use_pynccl else contextlib.nullcontext())
    with ctx:
        yield
def _get_graph_batch_size(batch_size: int) -> int:
    """Return the CUDA-graph (padded) batch size for ``batch_size``.

    Sizes 1 and 2 are kept as-is, 3 and 4 become 4, and anything larger
    is rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``
    (``_BATCH_SIZE_ALIGNMENT``, ``2*_BATCH_SIZE_ALIGNMENT``, ...).
    """
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    alignment = _BATCH_SIZE_ALIGNMENT
    num_chunks = (batch_size + alignment - 1) // alignment
    return num_chunks * alignment
def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run.

    Without a vision-language config, the prompt is ``seq_len`` zero tokens
    and no image is returned. With one, the prompt is ``image_feature_size``
    copies of ``image_token_id`` followed by zeros up to ``seq_len`` tokens,
    paired with an all-zero float16 image of ``image_input_shape``.

    Note: when ``seq_len < image_feature_size`` no zeros are appended, so
    the prompt ends up longer than ``seq_len``.

    Returns:
        A ``(SequenceData, MultiModalData or None)`` pair.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    image_token_id = vision_language_config.image_token_id
    num_image_tokens = vision_language_config.image_feature_size
    num_padding_tokens = seq_len - num_image_tokens
    prompt_tokens = ([image_token_id] * num_image_tokens +
                     [0] * num_padding_tokens)
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(prompt_tokens), fake_image_input
|