from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union

import numpy as np
import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad

# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")

PAD_SLOT_ID = -1

# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder


def is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False


def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int,
                                   use_v2_block_manager: bool):
    """
    Compute the start index of slot mapping.
    """
    start_idx = 0
    if is_prompt and sliding_window is not None:
        assert use_v2_block_manager or context_len == 0, (
            "Prefix caching is currently not supported with "
            "sliding window attention in V1 block manager")
        # When prefill, we use it to not write slots to kv cache
        # to save memory.
        start_idx = max(0, query_len - sliding_window)
    return start_idx


def _compute_slot_mapping_python(slot_mapping: List[int],
                                 block_table: List[int], range_start: int,
                                 range_end: int, block_size: int):
    for i in range(range_start, range_end):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot = block_number * block_size + block_offset
        slot_mapping.append(slot)


def _compute_slot_mapping_numpy(slot_mapping: List[int],
                                block_table: List[int], range_start: int,
                                range_end: int, block_size: int):
    block_table_array = np.array(block_table)
    idx = np.arange(range_start, range_end)
    block_offset = idx % block_size
    idx //= block_size
    seq_slot_mapping_array = block_table_array[idx]
    seq_slot_mapping_array *= block_size
    seq_slot_mapping_array += block_offset
    slot_mapping.extend(seq_slot_mapping_array)


def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """
    Compute slot mapping.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    padding_mask_len = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)

    range_start = max(start_idx, context_len)
    range_end = seq_len
    numel = range_end - range_start
    block_table = block_tables[seq_id]

    # numpy implementation will be faster than python if we have
    # many elements, otherwise it will be slower.
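    # Either way, each in-window position i maps to slot
    # block_table[i // block_size] * block_size + i % block_size. For example
    # (illustrative values only), with block_size=16 and block_table=[7, 3],
    # position 20 maps to 3 * 16 + 20 % 16 == 52.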
    if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        _compute_slot_mapping_python(slot_mapping, block_table, range_start,
                                     range_end, block_size)
    else:
        _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
                                    range_end, block_size)


TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')


class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids,
                 [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO: Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
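
        Returns:
            An instance of self._metadata_cls (the backend-specific
            AttentionMetadata subclass) populated with the tensors built
            below.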
""" for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) num_decode_tokens = batch_size # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.runner.graph_block_tables[:batch_size] for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table block_tables = torch.from_numpy(input_block_tables).to( device, non_blocking=True) else: block_tables = make_tensor_with_pad( self.block_tables, pad=0, dtype=torch.int, device=device, ) assert max_query_len > 0, "query_lens: {}".format(query_lens) assert device is not None context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype, out=query_start_loc[1:]) return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, )