block.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. """Token blocks."""
  2. from typing import List
  3. from aphrodite.common.utils import Device
  # Placeholder stored in every slot of a logical block that has not been
  # written yet.
  _BLANK_TOKEN_ID = -1


  class LogicalTokenBlock:
      """A block that stores a contiguous chunk of tokens from left to right.

      Logical blocks are used to represent the states of the corresponding
      physical blocks in the KV cache.

      Attributes are plain public fields: ``block_number`` and ``block_size``
      are stored as given, ``token_ids`` is a list of ``block_size`` slots
      (unused slots hold ``_BLANK_TOKEN_ID``) and ``num_tokens`` counts the
      slots filled so far.
      """

      def __init__(
          self,
          block_number: int,
          block_size: int,
      ) -> None:
          """Create an empty block with ``block_size`` blank slots."""
          self.block_number = block_number
          self.block_size = block_size
          self.token_ids = [_BLANK_TOKEN_ID] * block_size
          self.num_tokens = 0

      def is_empty(self) -> bool:
          """Return True when no token has been appended yet."""
          return self.num_tokens == 0

      def get_num_empty_slots(self) -> int:
          """Return how many more tokens the block can still take."""
          return self.block_size - self.num_tokens

      def is_full(self) -> bool:
          """Return True when every slot of the block holds a token."""
          return self.num_tokens == self.block_size

      def append_tokens(self, token_ids: List[int]) -> None:
          """Write ``token_ids`` into the next free slots of the block.

          Args:
              token_ids: Tokens to store, in order. An empty list is a no-op.

          Raises:
              AssertionError: If more tokens are given than there are empty
                  slots. The block is left unchanged in that case.
          """
          num_new = len(token_ids)
          num_free = self.get_num_empty_slots()
          # Raised explicitly rather than through ``assert`` so the check is
          # still performed under ``python -O``; without it the slice
          # assignment below would grow the list past ``block_size``.
          if num_new > num_free:
              raise AssertionError(
                  f"cannot append {num_new} token(s): only {num_free} empty "
                  f"slot(s) left in block {self.block_number}")
          start = self.num_tokens
          self.token_ids[start:start + num_new] = token_ids
          self.num_tokens += num_new

      def get_token_ids(self) -> List[int]:
          """Return a new list with the tokens stored so far (no blanks)."""
          return self.token_ids[:self.num_tokens]

      def get_last_token_id(self) -> int:
          """Return the most recently appended token.

          Raises:
              AssertionError: If the block is empty.
          """
          # Explicit raise (not ``assert``): under ``python -O`` an empty
          # block would otherwise index ``token_ids[-1]`` and hand back the
          # blank placeholder as if it were a real token.
          if self.num_tokens <= 0:
              raise AssertionError(
                  f"block {self.block_number} holds no tokens")
          return self.token_ids[self.num_tokens - 1]
  class PhysicalTokenBlock:
      """State record for one block of the KV cache.

      Holds the device the block belongs to, its block number, its size and a
      reference count.
      """

      def __init__(
          self,
          device: Device,
          block_number: int,
          block_size: int,
      ) -> None:
          """Store the block's identity; the reference count starts at zero."""
          self.device = device
          self.block_number = block_number
          self.block_size = block_size
          # Nothing in this class updates the counter after construction;
          # presumably the code that hands out blocks maintains it.
          self.ref_count = 0

      def __repr__(self) -> str:
          """Return a debug string with device, block number and ref count."""
          described = (f'device={self.device}, '
                       f'block_number={self.block_number}, '
                       f'ref_count={self.ref_count}')
          return f'PhysicalTokenBlock({described})'


  # Block table: maps a logical block number (the list index) to the
  # physical block stored at that position.
  BlockTable = List[PhysicalTokenBlock]