metadata.py

from typing import Optional

import torch


class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        is_prompt: Whether the batch contains prompt (prefill) tokens.
        slot_mapping: The address to write the new KV to of each token.
        prompt_lens: Lengths of prompts.
        max_seq_len: The maximum sequence length in the batch (prefill only).
        start_loc: The start index of each sequence in the batch (prefill only).
        max_context_len: The maximum context length.
        context_lens: The length of attention context for each sequence.
        block_tables: The block tables. (Seq id -> list of physical block)
        use_cuda_graph: Whether to run with a captured CUDA graph.
        kv_cache_dtype: Data type to store the KV cache.
    """
    def __init__(
        self,
        is_prompt: bool,
        slot_mapping: torch.Tensor,
        prompt_lens: Optional[torch.Tensor],
        max_seq_len: Optional[int],
        start_loc: Optional[torch.Tensor],
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
        kv_cache_dtype: str,
    ) -> None:
        self.is_prompt = is_prompt
        self.prompt_lens = prompt_lens
        self.max_seq_len = max_seq_len
        self.start_loc = start_loc
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph
        self.kv_cache_dtype = kv_cache_dtype

        # Set during the execution of the first attention op.
        # FIXME: This is a hack.
        self.attn_bias = None
    def __repr__(self) -> str:
        return ("InputMetadata("
                f"is_prompt={self.is_prompt}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph}, "
                f"kv_cache_dtype={self.kv_cache_dtype})")