from typing import Optional, List

import torch


class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Every constructor argument is stored unchanged as an attribute of the
    same name.

    Args:
        is_prompt: Flag stored as-is. Presumably True when the batch holds
            prompt (prefill) tokens rather than decode tokens; confirm
            against callers.
        slot_mapping: The address to write the new KV to of each token.
        prompt_lens: Lengths of prompts.
        max_seq_len: Maximum sequence length. Presumably the longest prompt
            in the batch; TODO confirm.
        start_loc: Presumably the start offset of each sequence within the
            flattened token batch; TODO confirm.
        max_context_len: The maximum context length.
        context_lens: The length of attention context for each sequence.
        block_tables: The block tables (seq id -> list of physical blocks).
        use_cuda_graph: Flag stored as-is. Presumably whether the forward
            pass runs under a captured CUDA graph; TODO confirm.
        kv_cache_dtype: Data type to store the KV cache.
        kv_quant_params: KV quant scales and zero points for int8 KV cache.

    Attributes:
        attn_bias: Initialized to ``None``. Set during the execution of the
            first attention op.
    """

    def __init__(
        self,
        is_prompt: bool,
        slot_mapping: torch.Tensor,
        prompt_lens: Optional[torch.Tensor],
        max_seq_len: Optional[int],
        start_loc: Optional[torch.Tensor],
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
        kv_cache_dtype: str,
        kv_quant_params: List[List[float]],
    ) -> None:
        self.is_prompt = is_prompt
        self.prompt_lens = prompt_lens
        self.max_seq_len = max_seq_len
        self.start_loc = start_loc
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_quant_params = kv_quant_params

        # Set during the execution of the first attention op.
        # FIXME: This is a hack: per-batch state is cached on the metadata
        # object instead of being passed explicitly to the attention op.
        self.attn_bias = None

    def __repr__(self) -> str:
        # List every stored constructor field in constructor order so debug
        # output does not silently hide any of them.
        return ("InputMetadata("
                f"is_prompt={self.is_prompt}, "
                f"prompt_lens={self.prompt_lens}, "
                f"max_seq_len={self.max_seq_len}, "
                f"start_loc={self.start_loc}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph}, "
                f"kv_cache_dtype={self.kv_cache_dtype}, "
                f"kv_quant_params={self.kv_quant_params})")