metadata.py

from typing import List, Optional

import torch


class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        is_prompt: Whether the input contains prompt (prefill) tokens.
        slot_mapping: The physical cache slot to write each token's new KV to.
        prompt_lens: Lengths of prompts.
        max_seq_len: The maximum sequence length in the batch.
        start_loc: The start location of each sequence in the batch.
        max_context_len: The maximum context length.
        context_lens: The length of attention context for each sequence.
        block_tables: The block tables. (Seq id -> list of physical blocks)
        use_cuda_graph: Whether to run the model with CUDA graph.
        kv_cache_dtype: Data type used to store the KV cache.
        kv_quant_params: KV quant scales and zero points for int8 KV cache.
    """

    def __init__(
        self,
        is_prompt: bool,
        slot_mapping: torch.Tensor,
        prompt_lens: Optional[torch.Tensor],
        max_seq_len: Optional[int],
        start_loc: Optional[torch.Tensor],
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
        kv_cache_dtype: str,
        kv_quant_params: List[List[float]],
    ) -> None:
        self.is_prompt = is_prompt
        self.prompt_lens = prompt_lens
        self.max_seq_len = max_seq_len
        self.start_loc = start_loc
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_quant_params = kv_quant_params

        # Set during the execution of the first attention op.
        # FIXME: This is a hack.
        self.attn_bias = None

    def __repr__(self) -> str:
        return ("InputMetadata("
                f"is_prompt={self.is_prompt}, "
                f"prompt_lens={self.prompt_lens}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph}, "
                f"kv_cache_dtype={self.kv_cache_dtype}, "
                f"kv_quant_params={self.kv_quant_params})")
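

# A minimal construction sketch (illustrative, not part of the original file):
# the values below assume a single decode step for two sequences, each of whose
# context fits in one physical block. The concrete slot indices, context
# lengths, block ids, and the "auto" kv_cache_dtype string are all assumptions
# for demonstration, not a required calling convention.
if __name__ == "__main__":
    metadata = InputMetadata(
        is_prompt=False,                  # decoding, not prefill
        slot_mapping=torch.tensor([5, 21], dtype=torch.long),  # one KV slot per new token
        prompt_lens=None,                 # prompt-only fields are unused here
        max_seq_len=None,
        start_loc=None,
        max_context_len=8,
        context_lens=torch.tensor([6, 8], dtype=torch.int),
        block_tables=torch.tensor([[0], [1]], dtype=torch.int),  # seq -> physical blocks
        use_cuda_graph=False,
        kv_cache_dtype="auto",            # assumed placeholder value
        kv_quant_params=[],               # empty: no int8 KV quantization
    )
    print(metadata)                       # exercises __repr__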