1
0

block.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. """Token blocks."""
  2. from typing import List
  3. from aphrodite.common.utils import Device
  4. _BLANK_TOKEN_ID = -1
  5. DEFAULT_LAST_ACCESSED_TIME = -1
  6. class LogicalTokenBlock:
  7. """A block that stores a contiguous chunk of tokens from left to right.
  8. Logical blocks are used to represent the states of the corresponding
  9. physical blocks in the KV cache.
  10. """
  11. def __init__(
  12. self,
  13. block_number: int,
  14. block_size: int,
  15. ) -> None:
  16. self.block_number = block_number
  17. self.block_size = block_size
  18. self.token_ids = [_BLANK_TOKEN_ID] * block_size
  19. self.num_tokens = 0
  20. def is_empty(self) -> bool:
  21. return self.num_tokens == 0
  22. def get_num_empty_slots(self) -> int:
  23. return self.block_size - self.num_tokens
  24. def is_full(self) -> bool:
  25. return self.num_tokens == self.block_size
  26. def append_tokens(self, token_ids: List[int]) -> None:
  27. assert len(token_ids) <= self.get_num_empty_slots()
  28. curr_idx = self.num_tokens
  29. self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
  30. self.num_tokens += len(token_ids)
  31. def get_token_ids(self) -> List[int]:
  32. return self.token_ids[:self.num_tokens]
  33. def get_last_token_id(self) -> int:
  34. assert self.num_tokens > 0
  35. return self.token_ids[self.num_tokens - 1]
  36. class PhysicalTokenBlock:
  37. """Represents the state of a block in the KV cache."""
  38. def __init__(
  39. self,
  40. device: Device,
  41. block_number: int,
  42. block_size: int,
  43. block_hash: int,
  44. num_hashed_tokens: int,
  45. ) -> None:
  46. self.device = device
  47. self.block_number = block_number
  48. self.block_size = block_size
  49. self.block_hash = block_hash
  50. self.num_hashed_tokens = num_hashed_tokens
  51. self.ref_count = 0
  52. self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
  53. self.computed = False
  54. def __repr__(self) -> str:
  55. return (f'PhysicalTokenBlock(device={self.device}, '
  56. f'block_number={self.block_number}, '
  57. f'num_hashed_tokens={self.num_hashed_tokens}, '
  58. f'ref_count={self.ref_count}, '
  59. f'last_accessed={self.last_accessed}, '
  60. f'computed={self.computed})')
# Block table type alias: the list index is the logical block number and the
# element is the PhysicalTokenBlock backing that logical block.
BlockTable = List[PhysicalTokenBlock]