# metadata.py
from typing import Dict, List, Optional, Tuple

import torch
from xformers.ops import AttentionBias

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import SequenceData


class PersistentMetadata:
    """Per-sequence metadata dictionaries, keyed by sequence id."""

    def __init__(self):
        self._metadata: Dict[int, dict] = {}

    def get(self, seq_id: int) -> dict:
        # Unknown sequence ids return an empty dict rather than raising.
        return self._metadata.get(seq_id, {})


class OutputMetadata(PersistentMetadata):
    """PersistentMetadata that can also be written to."""

    def add(self, seq_id: int, key, val) -> None:
        if seq_id not in self._metadata:
            self._metadata[seq_id] = {}
        self._metadata[seq_id][key] = val
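
# Example usage of OutputMetadata (illustrative sketch only; the seq_id,
# key, and value below are made up):
#
#     out = OutputMetadata()
#     out.add(seq_id=0, key="step", val=1)
#     out.get(0)   # -> {"step": 1}
#     out.get(99)  # -> {} for unknown sequence ids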


class InputMetadata:
    """Metadata for input sequences. Used for PagedAttention.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Mapping from sequence id to SequenceData.
        prompt_lens: Lengths of the prompts.
        slot_mapping: The KV cache address each token's new key/value is
            written to.
        context_lens: The length of the attention context for each
            generation token.
        max_context_len: The maximum context length.
        block_tables: The block tables (sequence id -> list of physical
            blocks).
        sliding_window: Size of the sliding attention window, if used.
        persistent_data: Per-sequence persistent metadata; defaults to an
            empty PersistentMetadata.
    """
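
    # Shape sketch (illustrative numbers): with 2 prompts padded to
    # max_prompt_len=8 and 3 generation sequences, slot_mapping holds
    # 2 * 8 + 3 entries, context_lens and block_tables each have 3 rows,
    # and num_prompt_tokens is 16.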

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_data: Dict[int, SequenceData],
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        block_tables: torch.Tensor,
        sliding_window: Optional[int] = None,
        persistent_data: Optional[PersistentMetadata] = None,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables
        self.persistent_data = persistent_data or PersistentMetadata()

        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
        self.to_cache = None
        if sliding_window is not None:
            # We need to keep track of the positions of the sliding windows
            # within the key/value tables; this tells us which elements we
            # need to cache and where.
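            #
            # Worked example (illustrative): with sliding_window=2 and
            # prompt_lens=[4], only prompt positions 2 and 3 fall inside the
            # window, so to_cache starts as [2, 3]. start_idx then advances
            # by max_prompt_len (prompt tokens are laid out padded to a
            # common length), and every remaining slot_mapping entry, i.e.
            # the generation tokens, is appended so their KV is always
            # cached.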
            to_cache, start_idx = [], 0
            for prompt_len in self.prompt_lens:
                to_cache.extend(
                    range(
                        start_idx + max(0, prompt_len - sliding_window),
                        start_idx + prompt_len,
                    ))
                start_idx += self.max_prompt_len
            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
            self.to_cache = torch.tensor(to_cache,
                                         dtype=torch.int32,
                                         device=self.slot_mapping.device)

        self.num_prompts = len(prompt_lens)
        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = slot_mapping.shape[0]
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
        assert block_tables.shape[0] == self.num_generation_tokens

        # Set during the execution of the first attention op.
        self.attn_bias: Optional[AttentionBias] = None

    def __repr__(self) -> str:
        # Print only useful metadata.
        return (f'InputMetadata('
                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'num_prompts={self.num_prompts}, '
                f'prompt_lens={self.prompt_lens}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'context_lens={self.context_lens}, '
                f'max_context_len={self.max_context_len}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                f'block_tables={self.block_tables}, '
                f'slot_mapping={self.slot_mapping}, '
                f'persistent_data={self.persistent_data})')
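

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; the values below are made up):
    # build an InputMetadata for a single 4-token prompt and no generation
    # tokens. Empty seq_groups/seq_data keep the example self-contained;
    # real callers pass SamplingParams and SequenceData objects.
    metadata = InputMetadata(
        seq_groups=[],
        seq_data={},
        prompt_lens=[4],
        slot_mapping=torch.arange(4, dtype=torch.int64),
        context_lens=torch.empty(0, dtype=torch.int32),
        max_context_len=0,
        block_tables=torch.empty((0, 0), dtype=torch.int32),
        sliding_window=2,
    )
    print(metadata)
    # With sliding_window=2, only the last two prompt positions (2 and 3)
    # are kept in the cache window.
    print("to_cache:", metadata.to_cache)  # tensor([2, 3], dtype=torch.int32)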