12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- from typing import Optional, List
- import torch
class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    The constructor only stores its arguments on the instance; no validation
    or conversion is performed here.

    Args:
        is_prompt: Stored as-is. Presumably True when the batch is in the
            prompt phase rather than decoding -- TODO confirm with callers.
        slot_mapping: The address to write the new KV to of each token.
        prompt_lens: Lengths of prompts.
        max_seq_len: Stored as-is. Presumably the longest sequence length in
            the batch -- TODO confirm with callers.
        start_loc: Stored as-is. Presumably the start offset of each sequence
            within the batch -- TODO confirm with callers.
        max_context_len: The maximum context length.
        context_lens: The length of attention context for each sequence.
        block_tables: The block tables. (Seq id -> list of physical block)
        use_cuda_graph: Stored as-is. Presumably whether execution goes
            through a CUDA graph -- TODO confirm with callers.
        kv_cache_dtype: Data type to store KV cache.
        kv_quant_params: KV quant scales and zero points for int8 kv cache.
    """

    def __init__(
        self,
        is_prompt: bool,
        slot_mapping: torch.Tensor,
        prompt_lens: Optional[torch.Tensor],
        max_seq_len: Optional[int],
        start_loc: Optional[torch.Tensor],
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
        kv_cache_dtype: str,
        kv_quant_params: List[List[float]],
    ) -> None:
        self.is_prompt = is_prompt
        self.prompt_lens = prompt_lens
        self.max_seq_len = max_seq_len
        self.start_loc = start_loc
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_quant_params = kv_quant_params

        # Set during the execution of the first attention op.
        # FIXME: This is a hack.
        self.attn_bias = None

    def __repr__(self) -> str:
        """Return a debug string listing every constructor-supplied field.

        ``attn_bias`` is intentionally left out: it is not a constructor
        argument and starts as None.
        """
        return ("InputMetadata("
                f"is_prompt={self.is_prompt}, "
                f"prompt_lens={self.prompt_lens}, "
                f"max_seq_len={self.max_seq_len}, "
                f"start_loc={self.start_loc}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph}, "
                f"kv_cache_dtype={self.kv_cache_dtype}, "
                f"kv_quant_params={self.kv_quant_params})")
|