metadata.py
from typing import Optional

import torch


class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        prompt_lens: Lengths of prompts.
        slot_mapping: The address to which the new KV of each token is written.
        max_context_len: The maximum context length.
        context_lens: The length of the attention context for each sequence.
        block_tables: The block tables. (Seq id -> list of physical blocks)
    """

    def __init__(
        self,
        is_prompt: bool,
        slot_mapping: torch.Tensor,
        prompt_lens: Optional[torch.Tensor],
        max_seq_len: Optional[int],
        start_loc: Optional[torch.Tensor],
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
    ) -> None:
        self.is_prompt = is_prompt
        self.prompt_lens = prompt_lens
        self.max_seq_len = max_seq_len
        self.start_loc = start_loc
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph

        # Set during the execution of the first attention op.
        # FIXME: This is a hack.
        self.attn_bias = None

    def __repr__(self) -> str:
        return ("InputMetadata("
                f"is_prompt={self.is_prompt}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph})")