12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226 |
- import contextlib
- import time
- from enum import IntEnum
- from typing import Dict, List, NamedTuple, Optional, Set, Tuple
- import numpy as np
- import torch
- import torch.nn as nn
- from loguru import logger
- from aphrodite.attention import (
- AttentionMetadata,
- AttentionMetadataPerStage,
- get_attn_backend,
- )
- from aphrodite.common.config import (
- DeviceConfig,
- LoRAConfig,
- ModelConfig,
- ParallelConfig,
- SchedulerConfig,
- VisionLanguageConfig,
- )
- from aphrodite.common.logger import get_loading_progress_bar
- from aphrodite.common.sampling_params import SamplingParams, SamplingType
- from aphrodite.common.sequence import (
- MultiModalData,
- SamplerOutput,
- SequenceData,
- SequenceGroupMetadata,
- )
- from aphrodite.common.utils import (
- CudaMemoryProfiler,
- async_tensor_h2d,
- is_hip,
- is_pin_memory_available,
- make_tensor_with_pad,
- maybe_expand_dim,
- )
- from aphrodite.distributed import (
- broadcast_tensor_dict,
- get_tensor_model_parallel_world_size,
- with_pynccl_for_all_reduce,
- )
- from aphrodite.distributed.device_communicators import (
- custom_all_reduce,
- pynccl_utils,
- )
- from aphrodite.lora.layers import LoRAMapping
- from aphrodite.lora.request import LoRARequest
- from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
- from aphrodite.modeling import SamplingMetadata
- from aphrodite.modeling.loader import get_model
- from aphrodite.modeling.sampling_metadata import PersistentMetadata
# Slot id written into the slot mapping for padded or masked positions.
_PAD_SLOT_ID = -1
# Rank handed to the LoRA manager when dummy adapters are created for warm-up.
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4]
_BATCH_SIZES_TO_CAPTURE.extend(
    _BATCH_SIZE_ALIGNMENT * multiple for multiple in range(1, 33))
class PreparePromptMetadata(NamedTuple):
    """Flattened prefill inputs produced by ``ModelRunner._prepare_prompt``."""

    # Token ids of every prompt chunk in the batch, concatenated.
    input_tokens: List[int]
    # Position of each entry of ``input_tokens`` within its sequence.
    input_positions: List[int]
    # Attention metadata for the prefill stage; None for an empty batch.
    attn_metadata: Optional[AttentionMetadataPerStage]
    # Per sequence: index at which this prefill step ends.
    prompt_lens: List[int]
    # Per sequence: number of tokens actually fed to the model in this step.
    subquery_lens: List[int]
    # LoRA id per input token.
    lora_index_mapping: List[int]
    # LoRA id per sampled position.
    lora_prompt_mapping: List[int]
    # LoRA requests referenced by the batch.
    lora_requests: Set[LoRARequest]
    # Concatenated multi-modal tensors, or None when the batch has none.
    multi_modal_input: Optional[torch.Tensor]
    # KV-cache slot per input token (_PAD_SLOT_ID for masked tokens).
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance describing a batch with no prefill requests.

        Built through ``cls`` so that subclasses get an instance of their own
        type. Every call creates fresh containers, so callers may mutate the
        result freely.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )
class PrepareDecodeMetadata(NamedTuple):
    """Flattened decode inputs produced by ``ModelRunner._prepare_decode``."""

    # Last generated token id of every running sequence.
    input_tokens: List[int]
    # Position of that token within its sequence (sequence length - 1).
    input_positions: List[int]
    # Attention metadata for the decode stage; None for an empty batch.
    attn_metadata: Optional[AttentionMetadata]
    # LoRA id per input token.
    lora_index_mapping: List[int]
    # LoRA id per sampled position.
    lora_prompt_mapping: List[int]
    # LoRA requests referenced by the batch.
    lora_requests: Set[LoRARequest]
    # KV-cache slot per input token (_PAD_SLOT_ID for padding entries).
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance describing a batch with no decode requests.

        Built through ``cls`` so that subclasses get an instance of their own
        type. Every call creates fresh containers, so callers may mutate the
        result freely.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )
class BatchType(IntEnum):
    """How a batch is composed with respect to prefill and decode requests."""

    PREFILL = 0  # The batch holds prefill requests only.
    DECODE = 1  # The batch holds decode requests only.
    MIXED = 2  # The batch holds both prefill and decode requests.
- class ModelRunner:
def __init__(
    self,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    lora_config: Optional[LoRAConfig],
    kv_cache_dtype: Optional[str] = "auto",
    is_driver_worker: bool = False,
    vision_language_config: Optional[VisionLanguageConfig] = None,
):
    """Store the configs and initialise the runner's lazily-filled state.

    The model, block size, LoRA manager and CUDA-graph buffers are left
    unset here; they are filled in by later calls (``load_model``,
    ``set_block_size``, graph capture).
    """
    self.model_config = model_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.lora_config = lora_config
    self.is_driver_worker = is_driver_worker

    # model_config can be None in tests/samplers/test_sampler.py.
    # FIXME: This is a hack to make the tests work. Refactor this.
    has_model_config = model_config is not None
    if has_model_config:
        self.sliding_window = model_config.get_sliding_window()
    else:
        self.sliding_window = None

    # Fall back to a default device config when none is supplied.
    if device_config is None:
        device_config = DeviceConfig()
    self.device_config = device_config
    self.device = self.device_config.device

    self.model = None
    self.block_size = None  # Set after initial profiling.
    self.lora_manager = None

    self.graph_runners: Dict[int, CUDAGraphRunner] = {}
    self.graph_memory_pool = None  # Set during graph capture.

    if has_model_config:
        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture)
    else:
        self.max_context_len_to_capture = 0

    # When using CUDA graph, the input block tables must be padded to
    # max_context_len_to_capture. However, creating the block table in
    # Python can be expensive. To optimize this, we cache the block table
    # in numpy and only copy the actual input content at every iteration.
    # The shape of the cached block table will be
    # (max batch size to capture, max context len to capture / block size).
    self.graph_block_tables = None  # Set after initial profiling.

    self.pin_memory = is_pin_memory_available()
    self.kv_cache_dtype = kv_cache_dtype
    self.vision_language_config = vision_language_config

    model_dtype = self.model_config.dtype if has_model_config else None
    self.attn_backend = get_attn_backend(model_dtype)
def load_model(self) -> None:
    """Load the model weights and attach optional LoRA / KV-cache scaling.

    Records the memory consumed while loading in ``self.model_memory_usage``
    and logs it. When a LoRA config is present, the model is wrapped by an
    ``LRUCacheWorkerLoRAManager``. When the KV cache dtype is ``"fp8"`` on
    ROCm, scaling factors are loaded if a path was provided.

    Raises:
        RuntimeError: If FP8 KV-cache scaling factors are provided but the
            model has no callable ``load_kv_cache_scales``.
    """
    with CudaMemoryProfiler() as profiler:
        self.model = get_model(
            self.model_config,
            self.device_config,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
        )

    self.model_memory_usage = profiler.consumed_memory
    tp = get_tensor_model_parallel_world_size()
    logger.info(
        "Model weights loaded. Memory usage: "
        f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
        f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")

    if self.lora_config:
        # The model must expose the attributes the LoRA manager relies on.
        assert (hasattr(self.model, "supported_lora_modules")
                and self.model.supported_lora_modules), (
                    "Model does not support LoRA")
        assert hasattr(
            self.model,
            "embedding_modules"), "Model does not have embedding_modules"
        assert hasattr(
            self.model, "embedding_padding_modules"
        ), "Model does not have embedding_padding_modules"

        self.lora_manager = LRUCacheWorkerLoRAManager(
            self.scheduler_config.max_num_seqs,
            self.scheduler_config.max_num_batched_tokens,
            self.vocab_size,
            self.lora_config,
            self.device,
            self.model.embedding_modules,
            self.model.embedding_padding_modules,
        )
        self.model = self.lora_manager.create_lora_manager(self.model)

    if self.kv_cache_dtype == "fp8" and is_hip():
        # Currently scaled KV cache is only enabled on ROCm
        if self.model_config.quantization_param_path is None:
            logger.warning(
                "Using FP8 KV cache but no scaling factors "
                "provided. Defaulting to scaling factors of 1.0. "
                "This may lead to less accurate results!")
        elif callable(getattr(self.model, "load_kv_cache_scales", None)):
            self.model.load_kv_cache_scales(
                self.model_config.quantization_param_path)
        else:
            raise RuntimeError("Using FP8 KV cache and scaling "
                               "factors provided but model "
                               f"{self.model.__class__} does not "
                               "support loading scaling factors.")
    elif self.model_config.quantization_param_path is not None:
        logger.warning("KV cache scaling factors provided, "
                       "but the KV cache data type is not FP8. "
                       "KV cache scaling factors will not be used.")
def set_block_size(self, block_size: int) -> None:
    """Record the block size and allocate the cached graph block table.

    The table is a zero-filled int32 numpy array with one row per sequence
    of the largest captured batch size and one column per block of the
    longest capturable context.
    """
    self.block_size = block_size
    table_shape = (
        max(_BATCH_SIZES_TO_CAPTURE),
        self.get_max_block_per_batch(),
    )
    self.graph_block_tables = np.zeros(table_shape, dtype=np.int32)
def get_max_block_per_batch(self) -> int:
    """Return how many blocks cover ``max_context_len_to_capture`` tokens.

    This is the context length divided by ``self.block_size``, rounded up.
    """
    tokens_per_block = self.block_size
    # Adding (block size - 1) before flooring rounds the quotient up.
    padded_len = self.max_context_len_to_capture + tokens_per_block - 1
    return padded_len // tokens_per_block
def _prepare_prompt(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
    """Flatten the prefill requests of a batch into model inputs.

    Every sequence group must be a prompt holding exactly one sequence.
    For each one, the tokens of the current prefill step, their positions,
    their KV-cache slots and the LoRA mappings are appended to flat lists,
    and prefill attention metadata is built from the per-sequence lengths.

    Args:
        seq_group_metadata_list: Prefill requests scheduled for this step.

    Returns:
        The collected inputs, or ``PreparePromptMetadata.empty()`` when the
        list is empty.

    Raises:
        RuntimeError: If chunked prefill is enabled and a request carries
            already-computed (prefix-cached) blocks.
    """
    if not seq_group_metadata_list:
        return PreparePromptMetadata.empty()

    input_tokens: List[int] = []
    input_positions: List[int] = []
    slot_mapping: List[int] = []
    lora_index_mapping: List[int] = []
    lora_prompt_mapping: List[int] = []
    lora_requests: Set[LoRARequest] = set()

    prompt_lens: List[int] = []
    context_lens: List[int] = []
    subquery_lens: List[int] = []
    prefix_block_tables: List[List[int]] = []
    multi_modal_input_list: List[torch.Tensor] = []

    # Evaluated once so that every branch below treats a missing
    # scheduler_config the same way (as "chunked prefill disabled").
    chunked_prefill_enabled = (
        self.scheduler_config is not None
        and self.scheduler_config.chunked_prefill_enabled)

    for seq_group_metadata in seq_group_metadata_list:
        assert seq_group_metadata.is_prompt
        seq_ids = list(seq_group_metadata.seq_data.keys())
        assert len(seq_ids) == 1
        seq_id = seq_ids[0]

        computed_block_nums = seq_group_metadata.computed_block_nums
        if (chunked_prefill_enabled
                and not (computed_block_nums is None
                         or computed_block_nums == [])):
            raise RuntimeError(
                "chunked prefill cannot be used with prefix caching "
                "now.")

        token_chunk_size = seq_group_metadata.token_chunk_size
        seq_data = seq_group_metadata.seq_data[seq_id]
        computed_len = seq_data.get_num_computed_tokens()
        # We should use get_len here because in case of preemption
        # it contains output tokens.
        prefill_end = min(seq_data.get_len(),
                          computed_len + token_chunk_size)
        prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
        prompt_len = prefill_end
        prompt_lens.append(prompt_len)

        # NOTE: This only works for oooooooxxx style attention.
        if (computed_block_nums is not None
                and len(computed_block_nums) > 0
                and self.sliding_window is None):
            # Prefix is not supported with sliding_window
            computed_len = len(computed_block_nums) * self.block_size
            prompt_tokens = prompt_tokens[computed_len:]
            prefix_block_tables.append(computed_block_nums)
        elif chunked_prefill_enabled:
            if seq_group_metadata.block_tables is not None:
                # Prefill has chunked before.
                block_table = seq_group_metadata.block_tables[seq_id]
                prefix_block_tables.append(block_table)
            else:
                # The first prefill.
                prefix_block_tables.append([])
        else:
            prefix_block_tables.append([])
            # Right now, prefill start is always 0. However, this
            # assumption can be changed once chunked prefill is introduced.
            assert computed_len == 0

        # Number of tokens actually fed to the model for this sequence.
        subquery_len = prompt_len - computed_len

        # actual prompt lens
        context_lens.append(computed_len)
        subquery_lens.append(subquery_len)

        input_tokens.extend(prompt_tokens)
        # NOTE: Here we assume that the first token in the prompt
        # is always the first token in the sequence.
        input_positions.extend(range(computed_len, prefill_end))

        lora_id = seq_group_metadata.lora_int_id
        if lora_id > 0:
            lora_requests.add(seq_group_metadata.lora_request)

        lora_index_mapping += [lora_id] * subquery_len
        # Must use the same test as _prepare_sample ("is not None"):
        # prompt_logprobs == 0 still selects every prompt position there, so
        # a truthiness check here would leave the mapping too short.
        wants_prompt_logprobs = (
            seq_group_metadata.sampling_params.prompt_logprobs is not None)
        lora_prompt_mapping.extend(
            [lora_id] * (subquery_len if wants_prompt_logprobs else 1))

        if seq_group_metadata.multi_modal_data:
            multi_modal_input_list.append(
                seq_group_metadata.multi_modal_data.data)

        if seq_group_metadata.block_tables is None:
            # During memory profiling, the block tables are not initialized
            # yet. In this case, we just use a dummy slot mapping.
            slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
            continue

        # Compute the slot mapping.
        block_table = seq_group_metadata.block_tables[seq_id]
        # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
        # where start_idx is max(0, prompt_len - sliding_window).
        # For example, if the prompt len is 10, sliding window is 8, and
        # block size is 4, the first two tokens are masked and the slot
        # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
        start_idx = 0
        if self.sliding_window is not None:
            assert computed_len == 0, (
                "Prefix caching is currently not supported with "
                "sliding window attention")
            start_idx = max(0, prompt_len - self.sliding_window)

        for i in range(computed_len, prefill_end):
            if i < start_idx:
                slot_mapping.append(_PAD_SLOT_ID)
                continue

            block_number = block_table[i // self.block_size]
            block_offset = i % self.block_size
            slot = block_number * self.block_size + block_offset
            slot_mapping.append(slot)

    max_subquery_len = max(subquery_lens)
    max_prompt_len = max(prompt_lens)
    assert max_subquery_len > 0

    context_lens_tensor = torch.tensor(context_lens,
                                       dtype=torch.int,
                                       device=self.device)

    if multi_modal_input_list:
        assert self.vision_language_config, (
            "Multi-modal inputs are only supported by "
            "vision language models.")
        multi_modal_input = torch.cat(multi_modal_input_list,
                                      dim=0).to(self.device)
    else:
        multi_modal_input = None

    # Prepare prefix block tables
    max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
    block_tables = make_tensor_with_pad(
        prefix_block_tables,
        max_len=max_prompt_block_table_len,
        pad=0,
        dtype=torch.int,
        device=self.device,
    )

    # Query length can be shorter than key (i.e., prompt) when prefill
    # is chunked or prefix cached.
    subquery_lens_tensor = torch.tensor(subquery_lens,
                                        dtype=torch.long,
                                        device=self.device)
    subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=self.device)

    prompt_lens_tensor = torch.tensor(prompt_lens,
                                      dtype=torch.long,
                                      device=self.device)
    seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                dtype=torch.int32,
                                device=self.device)

    # Element 0 stays 0; elements 1.. receive the running totals, giving the
    # start offset of every sequence in the flattened batch.
    torch.cumsum(subquery_lens_tensor,
                 dim=0,
                 dtype=subquery_start_loc.dtype,
                 out=subquery_start_loc[1:])
    torch.cumsum(prompt_lens_tensor,
                 dim=0,
                 dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])

    attn_metadata = self.attn_backend.make_metadata(
        is_prompt=True,
        prompt_lens=prompt_lens,
        prompt_lens_tensor=prompt_lens_tensor,
        max_subquery_len=max_subquery_len,
        max_context_len=None,
        max_prompt_len=max_prompt_len,
        subquery_start_loc=subquery_start_loc,
        seq_start_loc=seq_start_loc,
        context_lens=context_lens_tensor,
        block_tables=block_tables,
        use_cuda_graph=False,
    )
    return PreparePromptMetadata(
        input_tokens=input_tokens,
        input_positions=input_positions,
        attn_metadata=attn_metadata,
        prompt_lens=prompt_lens,
        subquery_lens=subquery_lens,
        lora_index_mapping=lora_index_mapping,
        lora_prompt_mapping=lora_prompt_mapping,
        lora_requests=lora_requests,
        multi_modal_input=multi_modal_input,
        slot_mapping=slot_mapping,
    )
def _prepare_decode(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PrepareDecodeMetadata:
    """Flatten the decode requests of a batch into model inputs.

    Each running sequence contributes its last token, that token's position
    and KV-cache slot, its context length and its block table. When a
    captured CUDA graph can be used, all per-sequence lists are padded up to
    the graph batch size and the cached numpy block table is reused.

    Args:
        seq_group_metadata_list: Decode requests scheduled for this step.

    Returns:
        The collected inputs, or ``PrepareDecodeMetadata.empty()`` when the
        list is empty.
    """
    if len(seq_group_metadata_list) == 0:
        return PrepareDecodeMetadata.empty()

    input_tokens: List[int] = []
    input_positions: List[int] = []
    slot_mapping: List[int] = []
    context_lens: List[int] = []
    block_tables: List[List[int]] = []
    lora_index_mapping: List[int] = []
    lora_prompt_mapping: List[int] = []
    lora_requests: Set[LoRARequest] = set()

    for group in seq_group_metadata_list:
        assert not group.is_prompt
        assert group.token_chunk_size == 1

        lora_id = group.lora_int_id
        if lora_id > 0:
            lora_requests.add(group.lora_request)

        for seq_id in list(group.seq_data.keys()):
            seq_data = group.seq_data[seq_id]
            input_tokens.append(seq_data.get_last_token_id())

            seq_len = seq_data.get_len()
            position = seq_len - 1
            input_positions.append(position)

            # With a sliding window the attended context is capped.
            if self.sliding_window is None:
                context_lens.append(seq_len)
            else:
                context_lens.append(min(seq_len, self.sliding_window))

            # Translate the token position into a physical cache slot.
            block_table = group.block_tables[seq_id]
            block_number = block_table[position // self.block_size]
            block_offset = position % self.block_size
            slot_mapping.append(block_number * self.block_size +
                                block_offset)

            lora_index_mapping.append(lora_id)
            lora_prompt_mapping.append(lora_id)

            if self.sliding_window is not None:
                # Keep only the trailing blocks that fit in the window.
                sliding_window_blocks = (self.sliding_window //
                                         self.block_size)
                block_table = block_table[-sliding_window_blocks:]
            block_tables.append(block_table)

    # Aphrodite uses CUDA graph only for decoding requests.
    # See `capture_model` API for more details.
    # For decoding requests, batch_size == input_tokens.
    batch_size = len(input_tokens)
    max_context_len = max(context_lens)
    use_captured_graph = (
        not self.model_config.enforce_eager
        and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
        and max_context_len <= self.max_context_len_to_capture)

    if use_captured_graph:
        graph_batch_size = _get_graph_batch_size(batch_size)
        assert graph_batch_size >= batch_size
        num_padding = graph_batch_size - batch_size
        # Pad every per-sequence list up to the captured batch size.
        input_tokens.extend([0] * num_padding)
        input_positions.extend([0] * num_padding)
        slot_mapping.extend([_PAD_SLOT_ID] * num_padding)
        context_lens.extend([1] * num_padding)
        block_tables.extend([] for _ in range(num_padding))
        lora_index_mapping.extend([0] * num_padding)
        batch_size = graph_batch_size

    context_lens = torch.tensor(context_lens,
                                dtype=torch.int,
                                device=self.device)

    if use_captured_graph:
        # When using cuda-graph all these tensors should be
        # padded.
        num_rows = context_lens.shape[0]
        assert num_rows == len(input_tokens)
        assert num_rows == len(input_positions)
        assert num_rows == len(slot_mapping)

        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        input_block_tables = self.graph_block_tables[:batch_size]
        for row, table in enumerate(block_tables):
            if table:
                input_block_tables[row, :len(table)] = table
        block_tables = torch.tensor(input_block_tables, device=self.device)
    else:
        max_block_table_len = max(len(table) for table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

    attn_metadata = self.attn_backend.make_metadata(
        is_prompt=False,
        prompt_lens=None,
        prompt_lens_tensor=None,
        max_subquery_len=None,
        max_context_len=max_context_len,
        max_prompt_len=None,
        subquery_start_loc=None,
        seq_start_loc=None,
        context_lens=context_lens,
        block_tables=block_tables,
        use_cuda_graph=use_captured_graph,
    )
    return PrepareDecodeMetadata(
        input_tokens=input_tokens,
        input_positions=input_positions,
        attn_metadata=attn_metadata,
        lora_index_mapping=lora_index_mapping,
        lora_prompt_mapping=lora_prompt_mapping,
        lora_requests=lora_requests,
        slot_mapping=slot_mapping,
    )
def _prepare_sample(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    prompt_lens: List[int],
    subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
    """Build the sampling metadata for every sequence group in the batch.

    Walks the groups in order and records which rows of the flattened model
    output are selected, plus, per sampling type, pairs of (sample index,
    sampled-token index). Seeded prompt requests get a fresh
    ``torch.Generator`` stored on their state.

    Args:
        seq_group_metadata_list: All requests of the batch.
        prompt_lens: Forwarded unchanged to ``SamplingMetadata``.
        subquery_lens: Number of input tokens per prompt request, indexed by
            the request's position in ``seq_group_metadata_list``; must not
            be None when the batch contains a prompt.

    Returns:
        The assembled ``SamplingMetadata``.
    """
    seq_groups: List[Tuple[List[int], SamplingParams]] = []
    selected_token_indices: List[int] = []
    generators: List[torch.Generator] = []
    categorized_sample_indices = {t: [] for t in SamplingType}

    # Running offsets into, respectively: the flattened model output, the
    # selected rows, and the sampled tokens.
    token_offset = 0
    sample_offset = 0
    sampled_token_offset = 0

    for i, group in enumerate(seq_group_metadata_list):
        seq_ids = list(group.seq_data.keys())
        sampling_params = group.sampling_params
        seq_groups.append((seq_ids, sampling_params))

        if group.is_prompt:
            assert len(seq_ids) == 1
            assert subquery_lens is not None
            subquery_len = subquery_lens[i]
            wants_prompt_logprobs = (sampling_params.prompt_logprobs
                                     is not None)

            if wants_prompt_logprobs:
                # NOTE: prompt token positions do not need sample, skip
                sample_offset += subquery_len - 1

            categorized_sample_indices[
                sampling_params.sampling_type].append(
                    [sample_offset, sampled_token_offset])
            sample_offset += 1
            sampled_token_offset += 1

            last_token_idx = token_offset + subquery_len - 1
            if wants_prompt_logprobs:
                # Select every prompt position before the last one as well.
                selected_token_indices.extend(
                    range(token_offset, last_token_idx))
            selected_token_indices.append(last_token_idx)
            token_offset += subquery_len

            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                assert sampling_params.seed is not None
                group.state.generator = torch.Generator(
                    device=self.device).manual_seed(sampling_params.seed)
        else:
            num_seqs = len(seq_ids)
            selected_token_indices.extend(
                range(token_offset, token_offset + num_seqs))
            token_offset += num_seqs

            sample_range = range(sample_offset, sample_offset + num_seqs)
            sampled_token_range = range(sampled_token_offset,
                                        sampled_token_offset + num_seqs)
            categorized_sample_indices[
                sampling_params.sampling_type].extend(
                    zip(sample_range, sampled_token_range))
            sample_offset += num_seqs
            sampled_token_offset += num_seqs

        if group.state.generator is not None:
            generators.append(group.state.generator)

    selected_token_indices = async_tensor_h2d(selected_token_indices,
                                              dtype=torch.long,
                                              target_device=self.device,
                                              pin_memory=self.pin_memory)

    categorized_sample_indices = {
        sampling_type: maybe_expand_dim(
            async_tensor_h2d(index_pairs,
                             dtype=torch.int,
                             target_device=self.device,
                             pin_memory=self.pin_memory), 2, 2)
        for sampling_type, index_pairs in categorized_sample_indices.items()
    }

    # Merge the per-group dictionaries into batch-wide ones.
    seq_data: Dict[int, SequenceData] = {}
    seq_persistence_data: Dict[int, dict] = {}
    for group in seq_group_metadata_list:
        seq_data.update(group.seq_data)
    for group in seq_group_metadata_list:
        seq_persistence_data.update(group.persistent_data)

    return SamplingMetadata(
        seq_groups=seq_groups,
        seq_data=seq_data,
        prompt_lens=prompt_lens,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        generators=generators,
        persistent_metadata=PersistentMetadata(seq_persistence_data),
    )
def prepare_input_tensors(
    self,
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
           Set[int], LoRAMapping, torch.Tensor]:
    """Produce the tensors and metadata needed to run one model step.

    On the driver worker the requests are split into prefill and decode,
    prepared, concatenated (prefill first) and broadcast from rank 0. A
    mixed batch triggers a second broadcast carrying the decode attention
    metadata. Other workers receive the broadcast(s) and rebuild the same
    objects, with a ``SamplingMetadata`` that has sampling disabled.

    Args:
        seq_group_metadata_list: Requests of this step; only read on the
            driver worker.

    Returns:
        ``(input_tokens, input_positions, attn_metadata, sampling_metadata,
        lora_requests, lora_mapping, multi_modal_input)``.
    """
    if self.is_driver_worker:
        prefill_reqs = []
        decode_reqs = []
        for group in seq_group_metadata_list:
            target = prefill_reqs if group.is_prompt else decode_reqs
            target.append(group)

        # Prepare input tensors.
        prefill = self._prepare_prompt(prefill_reqs)
        decode = self._prepare_decode(decode_reqs)
        prefill_attn_metadata = prefill.attn_metadata
        decode_attn_metadata = decode.attn_metadata
        multi_modal_input = prefill.multi_modal_input
        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 prefill.prompt_lens,
                                                 prefill.subquery_lens)

        if not self.scheduler_config.chunked_prefill_enabled:
            # Without chunked prefill a batch is never mixed.
            assert not (prefill_reqs and decode_reqs)

        num_prefills = len(prefill.prompt_lens)
        num_prefill_tokens = len(prefill.input_tokens)
        num_decode_tokens = len(decode.input_tokens)

        # Coalesce tensors. Note that attn_metadata is currently not
        # coalesced for simplicity.
        input_tokens = prefill.input_tokens
        input_positions = prefill.input_positions
        slot_mapping = prefill.slot_mapping
        lora_index_mapping = prefill.lora_index_mapping
        lora_prompt_mapping = prefill.lora_prompt_mapping
        lora_requests = prefill.lora_requests

        input_tokens.extend(decode.input_tokens)
        input_positions.extend(decode.input_positions)
        slot_mapping.extend(decode.slot_mapping)
        lora_index_mapping.extend(decode.lora_index_mapping)
        lora_prompt_mapping.extend(decode.lora_prompt_mapping)
        lora_requests.update(decode.lora_requests)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)

        lora_mapping = None
        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )

        # Broadcast the metadata.
        # If batch contains both prefill and decode, it sends 2 broadcasts.
        # If it only contains 1 type, it triggers a single broadcast.
        has_prefill = prefill_attn_metadata is not None
        has_decode = decode_attn_metadata is not None
        if has_prefill and has_decode:
            batch_type = BatchType.MIXED
        elif has_prefill:
            batch_type = BatchType.PREFILL
        else:
            batch_type = BatchType.DECODE

        metadata_dict = {
            "input_tokens": input_tokens,
            "input_positions": input_positions,
            "selected_token_indices":
            sampling_metadata.selected_token_indices,
            "lora_requests": lora_requests,
            "lora_mapping": lora_mapping,
            "multi_modal_input": multi_modal_input,
            "num_prefill_tokens": num_prefill_tokens,
            "num_decode_tokens": num_decode_tokens,
            "slot_mapping": slot_mapping,
            "num_prefills": num_prefills,
            "batch_type": batch_type,
        }
        # The first broadcast carries the prefill metadata when there is
        # any, otherwise the decode metadata.
        if has_prefill:
            metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
        else:
            metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
        broadcast_tensor_dict(metadata_dict, src=0)

        # Broadcast decode attn metadata for mixed batch type.
        # The additional broadcast costs 300us overhead on 4 A10 GPUs.
        # We can potentially reduce the overhead by coalescing tensors.
        if batch_type == BatchType.MIXED:
            assert decode_attn_metadata is not None
            metadata_dict = decode_attn_metadata.asdict_zerocopy()
            broadcast_tensor_dict(metadata_dict, src=0)
    else:
        metadata_dict = broadcast_tensor_dict(src=0)
        input_tokens = metadata_dict.pop("input_tokens")
        input_positions = metadata_dict.pop("input_positions")
        slot_mapping = metadata_dict.pop("slot_mapping")
        num_prefills = metadata_dict.pop("num_prefills")
        selected_token_indices = metadata_dict.pop(
            "selected_token_indices")
        lora_mapping = metadata_dict.pop("lora_mapping")
        lora_requests = metadata_dict.pop("lora_requests")
        multi_modal_input = metadata_dict.pop("multi_modal_input")
        num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
        num_decode_tokens = metadata_dict.pop("num_decode_tokens")
        batch_type = metadata_dict.pop("batch_type")

        # Create an attention metadata.
        # What is left in metadata_dict are the attention-metadata fields.
        prefill_attn_metadata = None
        decode_attn_metadata = None
        if batch_type in (BatchType.PREFILL, BatchType.MIXED):
            prefill_attn_metadata = self.attn_backend.make_metadata(
                **metadata_dict)
        else:
            decode_attn_metadata = self.attn_backend.make_metadata(
                **metadata_dict)

        sampling_metadata = SamplingMetadata(
            seq_groups=None,
            seq_data=None,
            prompt_lens=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            generators=None,
            perform_sampling=False,
        )

        # if it is a mixed batch, decode attn_metadata is broadcasted
        # separately.
        if batch_type == BatchType.MIXED:
            metadata_dict = broadcast_tensor_dict(src=0)
            decode_attn_metadata = self.attn_backend.make_metadata(
                **metadata_dict)

    attn_metadata = AttentionMetadata(
        num_prefills=num_prefills,
        slot_mapping=slot_mapping,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        prefill_metadata=prefill_attn_metadata,
        decode_metadata=decode_attn_metadata,
        kv_cache_dtype=self.kv_cache_dtype,
    )

    return (input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
    """Run one forward pass and, where sampling is enabled, sample tokens.

    Args:
        seq_group_metadata_list: Requests of this step, forwarded to
            ``prepare_input_tensors``.
        kv_caches: KV-cache tensors passed through to the model.

    Returns:
        The sampler output, or None when the sampling metadata has
        ``perform_sampling`` disabled.
    """
    prepared_inputs = self.prepare_input_tensors(seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata, sampling_metadata,
     lora_requests, lora_mapping, multi_modal_input) = prepared_inputs

    if self.lora_config:
        self.set_active_loras(lora_requests, lora_mapping)

    # Currently cuda graph is only supported by the decode phase.
    prefill_meta = attn_metadata.prefill_metadata
    decode_meta = attn_metadata.decode_metadata
    model_executable = self.model
    if prefill_meta is None and decode_meta.use_cuda_graph:
        # Graph runners are keyed by the batch size they were captured for.
        graph_batch_size = input_tokens.shape[0]
        model_executable = self.graph_runners[graph_batch_size]

    execute_model_kwargs = dict(
        input_ids=input_tokens,
        positions=input_positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
    )
    if self.vision_language_config:
        execute_model_kwargs["image_input"] = multi_modal_input
    hidden_states = model_executable(**execute_model_kwargs)

    # Compute the logits.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)

    # Only perform sampling in the driver worker.
    if not sampling_metadata.perform_sampling:
        return None

    # Sample the next token.
    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )
@torch.inference_mode()
def profile_run(self) -> None:
    """Execute the model once on dummy prompt inputs for memory profiling.

    Builds ``max_num_seqs`` prompt sequence groups whose lengths add up to
    ``max_num_batched_tokens``, runs ``execute_model`` on them with
    placeholder (``None``) KV caches, and waits for the GPU to finish.
    When LoRA is configured, dummy adapters are registered first and
    assigned round-robin to the sequences.
    """
    # Enable top-k sampling to reflect the accurate memory usage.
    sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
    max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
    max_num_seqs = self.scheduler_config.max_num_seqs

    # This represents the maximum number of different requests
    # that will have unique loras, and therefore the max amount of memory
    # consumption. Create dummy lora request copies from the lora request
    # passed in, which contains a lora from the lora warmup path.
    dummy_lora_requests = []
    dummy_lora_requests_per_seq = []
    if self.lora_config:
        for idx in range(self.lora_config.max_loras):
            lora_id = idx + 1
            dummy_lora_request = LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_local_path="/not/a/real/path",
            )
            self.lora_manager.add_dummy_lora(dummy_lora_request,
                                             rank=LORA_WARMUP_RANK)
            dummy_lora_requests.append(dummy_lora_request)
        dummy_lora_requests_per_seq = [
            dummy_lora_requests[idx % len(dummy_lora_requests)]
            for idx in range(max_num_seqs)
        ]

    # Profile memory usage with max_num_sequences sequences and the total
    # number of tokens equal to max_num_batched_tokens.
    seqs: List[SequenceGroupMetadata] = []
    # Additional GPU memory may be needed for vision encoding, which needs
    # to be accounted for when calculating the GPU blocks for
    # Aphrodite blocker manager.
    # To exercise the worst scenario for GPU memory consumption,
    # the number of seqs (batch_size) is chosen to maximize the number
    # of images processed.
    if self.vision_language_config:
        image_feature_size = self.vision_language_config.image_feature_size
        max_image_seqs = max_num_batched_tokens // image_feature_size
        if max_image_seqs < 1:
            # A token budget smaller than one image would otherwise clamp
            # the batch to zero sequences and leave nothing to profile.
            logger.warning(
                "max_num_batched_tokens is smaller than the image feature "
                "size; profiling with a single sequence.")
            max_image_seqs = 1
        max_num_seqs = min(max_num_seqs, max_image_seqs)

    for group_id in range(max_num_seqs):
        # Spread the token budget evenly; the first
        # (max_num_batched_tokens % max_num_seqs) groups get one extra token.
        seq_len = (max_num_batched_tokens // max_num_seqs +
                   (group_id < max_num_batched_tokens % max_num_seqs))
        seq_data, fake_multi_modal_input = _prepare_fake_inputs(
            seq_len, self.vision_language_config)
        seq = SequenceGroupMetadata(
            request_id=str(group_id),
            is_prompt=True,
            seq_data={group_id: seq_data},
            sampling_params=sampling_params,
            block_tables=None,
            persistent_data={},
            lora_request=dummy_lora_requests_per_seq[group_id]
            if dummy_lora_requests_per_seq else None,
            multi_modal_data=fake_multi_modal_input,
        )
        seqs.append(seq)

    # Run the model with the dummy inputs.
    num_layers = self.model_config.get_num_layers(self.parallel_config)
    kv_caches = [None] * num_layers
    self.execute_model(seqs, kv_caches)
    torch.cuda.synchronize()
def remove_all_loras(self) -> bool:
    """Delegate to the LoRA manager's ``remove_all_loras``.

    Returns:
        Whatever ``self.lora_manager.remove_all_loras()`` returns.

    Raises:
        RuntimeError: If ``self.lora_manager`` is falsy.
    """
    manager = self.lora_manager
    if manager:
        return manager.remove_all_loras()
    raise RuntimeError("LoRA is not enabled.")
def set_active_loras(self, lora_requests: List[LoRARequest],
                     lora_mapping: LoRAMapping) -> None:
    """Forward the requests and mapping to the LoRA manager.

    Args:
        lora_requests: Passed through to the manager's ``set_active_loras``.
        lora_mapping: Passed through to the manager's ``set_active_loras``.

    Raises:
        RuntimeError: If ``self.lora_manager`` is falsy.
    """
    manager = self.lora_manager
    if manager:
        manager.set_active_loras(lora_requests, lora_mapping)
        return
    raise RuntimeError("LoRA is not enabled.")
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Delegate to the LoRA manager's ``add_lora``.

    Args:
        lora_request: Passed through to the manager unchanged.

    Returns:
        Whatever ``self.lora_manager.add_lora(lora_request)`` returns.

    Raises:
        RuntimeError: If ``self.lora_manager`` is falsy.
    """
    manager = self.lora_manager
    if manager:
        return manager.add_lora(lora_request)
    raise RuntimeError("LoRA is not enabled.")
def remove_lora(self, lora_id: int) -> bool:
    """Delegate to the LoRA manager's ``remove_lora``.

    Args:
        lora_id: Integer id passed through to the manager unchanged.

    Returns:
        Whatever ``self.lora_manager.remove_lora(lora_id)`` returns.

    Raises:
        RuntimeError: If ``self.lora_manager`` is falsy.
    """
    manager = self.lora_manager
    if manager:
        return manager.remove_lora(lora_id)
    raise RuntimeError("LoRA is not enabled.")
def list_loras(self) -> Set[int]:
    """Delegate to the LoRA manager's ``list_loras``.

    Returns:
        Whatever ``self.lora_manager.list_loras()`` returns.

    Raises:
        RuntimeError: If ``self.lora_manager`` is falsy.
    """
    manager = self.lora_manager
    if manager:
        return manager.list_loras()
    raise RuntimeError("LoRA is not enabled.")
@torch.inference_mode()
def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
    """Capture one CUDA graph of the model per supported batch size.

    Note that CUDA graph's performance gain is negligible if number
    of batched tokens are larger than 200. And since CUDA graph
    requires fixed sized tensors, supporting large/variable batch
    size requires high GPU memory overhead. Thus, Aphrodite only captures
    decoding requests. Mixed batch (chunked prefill + decoding) or
    prefill requests are not captured.

    Since it is used for decoding-only, it assumes there's only 1 token
    per sequence in the batch.

    Args:
        kv_caches: KV cache tensors passed to every capture.

    Each captured runner is stored in ``self.graph_runners`` under its
    batch size, and ``self.graph_memory_pool`` is updated after each
    capture so the next one reuses the same pool.
    """
    # NOTE: This is a hack to ensure that the NCCL backend is never
    # deleted before the CUDA graphs.
    self.pynccl_backend = pynccl_utils.get_nccl_backend()

    assert not self.model_config.enforce_eager
    logger.info("Capturing the model for CUDA graphs. This may lead to "
                "unexpected consequences if the model is not static. To "
                "run the model in eager mode, set 'enforce_eager=True' or "
                "use '--enforce-eager' in the CLI.")
    logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
                   "per GPU. If you are running out of memory, consider "
                   "decreasing `gpu_memory_utilization` or enforcing "
                   "eager mode.")
    start_time = time.perf_counter()

    # Dummy inputs sized for the largest batch; every smaller batch size
    # captures against a leading slice of these same tensors.
    largest_batch = max(_BATCH_SIZES_TO_CAPTURE)
    input_tokens = torch.zeros(largest_batch, dtype=torch.long).cuda()
    input_positions = torch.zeros(largest_batch, dtype=torch.long).cuda()
    slot_mapping = torch.empty(largest_batch, dtype=torch.long).cuda()
    slot_mapping.fill_(_PAD_SLOT_ID)
    context_lens = torch.ones(largest_batch, dtype=torch.int32).cuda()
    block_tables = torch.from_numpy(self.graph_block_tables).cuda()

    # Only capture sizes up to the padded form of the scheduler's maximum
    # number of sequences.
    size_limit = _get_graph_batch_size(self.scheduler_config.max_num_seqs)
    batch_size_capture_list = [
        size for size in _BATCH_SIZES_TO_CAPTURE if size <= size_limit
    ]

    # NOTE: There are 3 backends for all-reduce: custom all-reduce
    # kernel, PyNCCL, and PyTorch NCCL. When using CUDA graph, we use
    # either custom all-reduce kernel or PyNCCL. When not using CUDA
    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
    # We always prioritize using custom all-reduce kernel but fall back
    # to PyTorch or PyNCCL if it is disabled or not supported.

    # Initialize a new progress bar
    progress = get_loading_progress_bar()
    task = progress.add_task("[cyan]Capturing graph...",
                             total=len(batch_size_capture_list))

    with progress, custom_all_reduce.capture():
        # Iterate from the largest batch size down to the smallest.
        for batch_size in batch_size_capture_list[::-1]:
            # Create dummy attn_metadata describing a decode-only batch.
            attn_metadata = AttentionMetadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size,
                slot_mapping=slot_mapping[:batch_size],
                prefill_metadata=None,
                decode_metadata=self.attn_backend.make_metadata(
                    is_prompt=False,
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_prompt_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                ),
                kv_cache_dtype=self.kv_cache_dtype,
            )

            if self.lora_config:
                # No real adapters are active during capture: empty request
                # set plus an all-zero mapping.
                zero_ids = [0] * batch_size
                self.set_active_loras(
                    set(), LoRAMapping(zero_ids, [0] * batch_size))

            runner = CUDAGraphRunner(self.model)
            runner.capture(
                input_tokens[:batch_size],
                input_positions[:batch_size],
                kv_caches,
                attn_metadata,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = runner.graph.pool()
            self.graph_runners[batch_size] = runner
            # Update the progress bar
            progress.update(task, advance=1)

    elapsed_time = time.perf_counter() - start_time
    # This usually takes < 10 seconds.
    logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
- def __del__(self) -> None:
- # Delete the CUDA graphs before deleting the pynccl communicator.
- # NOTE: This is necessary because otherwise deadlocks can
- # happen.
- # FIXME: This is a bit hacky. Find a more robust solution.
- # TODO: when we get enough user feedback that pynccl is
- # more stable than cupy, we can remove this
- self.graph_runners.clear()
- self.pynccl_backend = None
@property
def vocab_size(self) -> int:
    """Vocabulary size as reported by ``self.model_config.get_vocab_size()``."""
    config = self.model_config
    return config.get_vocab_size()
class CUDAGraphRunner:
    """Captures a model forward pass into a CUDA graph and replays it.

    ``capture`` records the graph against a fixed set of input tensors and
    keeps references to them; ``forward`` copies new inputs into those
    tensors, replays the graph, and returns the recorded output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Set by capture(); None means nothing has been captured yet.
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Record the model's forward pass on the given tensors.

        Args:
            input_ids, positions, kv_caches, attn_metadata: Positional
                arguments forwarded to the model. The tensors are retained
                as the replay input buffers.
            memory_pool: Passed as ``pool`` to ``torch.cuda.graph``.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self.graph is None
        model_args = (input_ids, positions, kv_caches, attn_metadata)

        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE: Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Save the input and output buffers.
        decode_meta = attn_metadata.decode_metadata
        self.input_buffers = dict(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            slot_mapping=attn_metadata.slot_mapping,
            context_lens=decode_meta.context_lens,
            block_tables=decode_meta.block_tables,
        )
        self.output_buffers = dict(hidden_states=hidden_states)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Replay the captured graph on new inputs.

        ``kv_caches`` and ``**kwargs`` are accepted for call compatibility
        with the model but are not used here.

        Returns:
            The ``hidden_states`` tensor saved at capture time.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        decode_meta = attn_metadata.decode_metadata
        refreshed = (
            ("input_ids", input_ids),
            ("positions", positions),
            ("slot_mapping", attn_metadata.slot_mapping),
            ("context_lens", decode_meta.context_lens),
            ("block_tables", decode_meta.block_tables),
        )
        for buffer_name, new_value in refreshed:
            self.input_buffers[buffer_name].copy_(new_value,
                                                  non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        """Make the runner callable like the model; same as ``forward``."""
        result = self.forward(*args, **kwargs)
        return result
@contextlib.contextmanager
def _maybe_pynccl():
    """Wrap the body in ``with_pynccl_for_all_reduce()`` when applicable.

    The wrapper is entered only if pynccl is initialized and the custom
    all-reduce is not; otherwise the body runs with no extra context.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    if use_pynccl:
        scope = with_pynccl_for_all_reduce()
    else:
        scope = contextlib.nullcontext()
    with scope:
        yield
- def _get_graph_batch_size(batch_size: int) -> int:
- """Returns the padded batch size given actual batch size.
- Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
- 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
- """
- if batch_size <= 2:
- return batch_size
- elif batch_size <= 4:
- return 4
- else:
- return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
- _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run.

    Args:
        seq_len: Requested prompt length.
        vision_language_config: When truthy, the prompt starts with
            ``image_feature_size`` copies of ``image_token_id`` and a
            zero-filled float16 image of ``image_input_shape`` is attached.

    Returns:
        A ``(SequenceData, MultiModalData or None)`` pair.
    """
    if not vision_language_config:
        # Text-only model: all-zero prompt and no image input.
        return SequenceData([0] * seq_len), None

    feature_size = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * feature_size
    # A negative repeat count yields an empty list, so a seq_len shorter
    # than the image feature size leaves only the image tokens.
    filler_tokens = [0] * (seq_len - feature_size)
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(image_tokens + filler_tokens), fake_image_input
|