prefix.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from typing import Dict, List, Sequence, Tuple, Optional
  2. from aphrodite.common.block import BlockTable
  class Prefix:
      """Data and states associated with a prefix of prompt tokens for multiple
      sequence groups.

      NOTE: This feature is experimental and may be replaced with automatic
      prefix caching in the future.

      Args:
          token_ids: The token ids of the prefix. Stored internally as a tuple,
              so a list (or any other sequence of ints) is accepted.
          block_size: The block size of the executed model. Must be positive,
              and the number of token ids must be a multiple of it.

      Raises:
          ValueError: If ``block_size`` is not positive, or if the number of
              token ids is not a multiple of ``block_size``.
      """

      def __init__(
          self,
          token_ids: Sequence[int],
          block_size: int,
      ) -> None:
          # Keep an immutable copy: it is hashable and cannot be changed
          # through the caller's original container.
          self.token_ids = tuple(token_ids)
          self.block_size = block_size
          self.length = len(self.token_ids)
          if block_size <= 0:
              raise ValueError(
                  f"block_size must be positive, got {block_size}")
          if self.length % block_size != 0:
              raise ValueError(
                  f"Prefix length {self.length} is not a multiple of "
                  f"block_size {block_size}")
          # Hash the stored tuple rather than the raw argument; hashing a list
          # argument directly would raise TypeError.
          self.hash = hash(self.token_ids)
          # Block table backing this prefix; None until set_block_table().
          self.block_table: Optional["BlockTable"] = None
          # Initialised to False and never updated inside this class --
          # presumably flipped by the caller once the prefix has been
          # computed (TODO confirm against the scheduler/engine).
          self.computed = False

      @property
      def allocated(self) -> bool:
          """True once a block table has been assigned via set_block_table()."""
          return self.block_table is not None

      def get_num_blocks(self) -> int:
          """Return how many blocks of ``block_size`` tokens the prefix spans."""
          return self.length // self.block_size

      def get_block_numbers(self) -> List[int]:
          """Return the ``block_number`` of each entry of the block table, in order.

          The prefix must be allocated first; iterating the unset (None) table
          raises TypeError.
          """
          return [block.block_number for block in self.block_table]

      def get_length(self) -> int:
          """Return the number of token ids in the prefix."""
          return self.length

      def __hash__(self) -> int:
          # Same value as hash(self.token_ids), computed once in __init__.
          return self.hash

      def set_block_table(self, block_table: "BlockTable") -> None:
          """Store a copy of ``block_table`` (via its ``copy()`` method), so
          later changes to the caller's table container do not affect this
          prefix."""
          self.block_table = block_table.copy()
  class PrefixPool:
      """Manages all the prompt prefixes.

      NOTE: This feature is experimental and may be replaced with automatic
      prefix caching in the future.

      Args:
          block_size: The block size of the executed model.

      Attributes:
          prefixes: Dict mapping ``hash((prefix, lora_int_id))`` to the cached
              Prefix for that token sequence and LoRA id.
          block_size: The block size of the executed model.
      """

      def __init__(
          self,
          block_size: int,
      ) -> None:
          # TODO: Add a capacity limit to the prefix pool.
          self.prefixes: Dict[int, "Prefix"] = {}
          self.block_size = block_size
          # Exact (token_ids, lora_int_id) key behind each entry of
          # ``prefixes``. Used to tell a genuine cache hit from a hash collision
          # between different prefixes.
          self._prefix_keys: Dict[int, Tuple[Tuple[int, ...], int]] = {}

      def _truncate_token_ids(
              self, token_ids: Sequence[int]) -> Tuple[int, ...]:
          """Return ``token_ids`` cut down to a whole number of blocks, as a
          tuple. Tokens past the last full block are dropped."""
          new_length = len(token_ids) // self.block_size * self.block_size
          return tuple(token_ids[:new_length])

      def add_or_get_prefix(self, token_ids: Sequence[int],
                            lora_int_id: int) -> Optional["Prefix"]:
          """Return the pooled Prefix for ``token_ids`` and ``lora_int_id``,
          creating and caching it on first use.

          Args:
              token_ids: Prompt prefix tokens. Truncated to a whole number of
                  blocks before lookup.
              lora_int_id: LoRA id. It is part of the cache key, so the same
                  tokens under different LoRA ids map to different prefixes.

          Returns:
              None if fewer than ``block_size`` tokens were given. Otherwise
              the cached Prefix. On a hash collision with a different cached
              key, a fresh uncached Prefix is returned.
          """
          token_ids = self._truncate_token_ids(token_ids)
          if not token_ids:
              # Shorter than one block: nothing to share.
              return None
          prefix = Prefix(token_ids, self.block_size)
          prefix_hash = hash((prefix, lora_int_id))
          key = (token_ids, lora_int_id)

          cached = self.prefixes.get(prefix_hash)
          if cached is None:
              self.prefixes[prefix_hash] = prefix
              self._prefix_keys[prefix_hash] = key
              return prefix
          if self._prefix_keys.get(prefix_hash) == key:
              return cached
          # Hash collision with a different (tokens, LoRA) pair: never hand out
          # another request's prefix (and its blocks). Fall back to an uncached
          # Prefix, which is correct but not shared.
          return prefix