from typing import Dict, List, Tuple

import torch
from xformers.ops import AttentionBias

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import SequenceData


class InputMetadata:
    """Metadata for input sequences. Used for PagedAttention.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Mapping from seq_id to SequenceData.
        prompt_lens: Lengths of prompts.
        slot_mapping: The physical cache slot that each token's new KV is
            written to.
        context_lens: The length of the attention context for each generation
            token.
        max_context_len: The maximum context length.
        block_tables: The block tables. (seq_id -> list of physical block
            numbers)
    """

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_data: Dict[int, SequenceData],
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables
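
        # Derived counts, precomputed once so that the attention ops and the
        # sampler do not have to rederive them from the tensors above.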
        self.num_prompts = len(prompt_lens)
        self.num_prompt_tokens = sum(prompt_lens)
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = slot_mapping.shape[0]
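        # block_tables is empty during a prompt-only step, so guard against
        # reading the second dimension of an empty tensor.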
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
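        # Each generation token must have exactly one block table row and one
        # context length.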
        assert block_tables.shape[0] == self.num_generation_tokens
        assert context_lens.shape[0] == self.num_generation_tokens

        # Set during the execution of the first attention op.
        self.attn_bias: List[AttentionBias] = []

    def __repr__(self) -> str:
        # Print only useful metadata.
        return (f'InputMetadata('
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'num_prompts={self.num_prompts}, '
                f'prompt_lens={self.prompt_lens}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'context_lens={self.context_lens}, '
                f'max_context_len={self.max_context_len}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                f'block_tables={self.block_tables}, '
                f'slot_mapping={self.slot_mapping})')
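

# Minimal usage sketch, not part of the class: it assumes SamplingParams is
# constructible with defaults and that SequenceData accepts a list of prompt
# token ids; the token ids, cache slots, and block numbers below are made up.
if __name__ == "__main__":
    # Decode-only batch of two sequences: no prompts, one new KV slot per
    # sequence, a 3-token context each, and one physical block per sequence.
    metadata = InputMetadata(
        seq_groups=[([0], SamplingParams()), ([1], SamplingParams())],
        seq_data={
            0: SequenceData([11, 12, 13]),
            1: SequenceData([21, 22, 23]),
        },
        prompt_lens=[],
        slot_mapping=torch.tensor([2, 18], dtype=torch.int),
        context_lens=torch.tensor([3, 3], dtype=torch.int),
        max_context_len=3,
        block_tables=torch.tensor([[0], [1]], dtype=torch.int),
    )
    print(metadata)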