block.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. """Token blocks."""
  2. from typing import List, Optional
  3. from aphrodite.common.utils import Device
# Initial value stored in ``PhysicalTokenBlock.last_accessed`` when a block is
# constructed. NOTE(review): presumably a sentinel meaning "never accessed" --
# confirm against the code that updates ``last_accessed``.
DEFAULT_LAST_ACCESSED_TIME = -1
  5. class PhysicalTokenBlock:
  6. """Represents the state of a block in the KV cache."""
  7. def __init__(
  8. self,
  9. device: Device,
  10. block_number: int,
  11. block_size: int,
  12. block_hash: int,
  13. num_hashed_tokens: int,
  14. ) -> None:
  15. self.device = device
  16. self.block_number = block_number
  17. self.block_size = block_size
  18. self.block_hash = block_hash
  19. self.num_hashed_tokens = num_hashed_tokens
  20. self.ref_count = 0
  21. self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
  22. self.computed = False
  23. def __repr__(self) -> str:
  24. return (f'PhysicalTokenBlock(device={self.device}, '
  25. f'block_number={self.block_number}, '
  26. f'num_hashed_tokens={self.num_hashed_tokens}, '
  27. f'ref_count={self.ref_count}, '
  28. f'last_accessed={self.last_accessed}, '
  29. f'computed={self.computed})')
  30. class BlockTable:
  31. """Holds a list of blocks with caching of their associated block_ids
  32. """
  33. def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None):
  34. self._blocks: List[PhysicalTokenBlock] = []
  35. self._block_ids: List[int] = []
  36. if blocks is not None:
  37. for block in blocks:
  38. self.append(block)
  39. def append(self, block: PhysicalTokenBlock):
  40. self._blocks.append(block)
  41. self._block_ids.append(block.block_number)
  42. def __len__(self) -> int:
  43. return len(self._blocks)
  44. def __getitem__(self, key):
  45. return self._blocks[key]
  46. def __setitem__(self, key, value):
  47. if isinstance(key, slice):
  48. blocks = value
  49. self._blocks[key] = blocks
  50. self._block_ids[key] = [b.block_number for b in blocks]
  51. else:
  52. block = value
  53. self._blocks[key] = block
  54. self._block_ids[key] = block.block_number
  55. def reset(self):
  56. self._blocks = []
  57. self._block_ids = []
  58. def copy(self) -> "BlockTable":
  59. return BlockTable(self._blocks)
  60. def list(self) -> List[PhysicalTokenBlock]:
  61. return self._blocks
  62. def ids(self) -> List[int]:
  63. return self._block_ids