block.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. """Token blocks."""
  2. import weakref
  3. from collections import defaultdict
  4. from typing import Dict, List
  5. from aphrodite.common.utils import Device
# Placeholder written into every slot of a freshly created token block.
_BLANK_TOKEN_ID = -1
# Initial `last_accessed` value given to a newly constructed physical block.
DEFAULT_LAST_ACCESSED_TIME = -1
# A token block is stored as a plain list of integer token ids.
TokensBlock = List[int]
  9. class BlockPool:
  10. """A pool of logical blocks.
  11. When requests come, we create a lot of logical blocks;
  12. when requests are done, we destroy a lot of logical blocks.
  13. It turns out that creating and destroying logical blocks can be expensive,
  14. especially for the `token_ids` field, which is a list of integers.
  15. To avoid this overhead, we use a pool to manage the logical blocks.
  16. When an old request is done and a new request comes, we can reuse the
  17. logical blocks from the old request to feed the new request.
  18. """
  19. def __init__(self) -> None:
  20. # block size to list of token blocks
  21. self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)
  22. def alloc_block(self, block_size: int) -> TokensBlock:
  23. if block_size in self.pool and self.pool[block_size]:
  24. return self.pool[block_size].pop()
  25. return [_BLANK_TOKEN_ID] * block_size
  26. def del_block(self, block: TokensBlock) -> None:
  27. self.pool[len(block)].append(block)
# Module-wide pool that LogicalTokenBlock draws its token lists from and
# returns them to.
_BLOCK_POOL = BlockPool()
  29. class LogicalTokenBlock:
  30. """A block that stores a contiguous chunk of tokens from left to right.
  31. Logical blocks are used to represent the states of the corresponding
  32. physical blocks in the KV cache.
  33. """
  34. def __init__(
  35. self,
  36. block_number: int,
  37. block_size: int,
  38. ) -> None:
  39. self.block_number = block_number
  40. self.block_size = block_size
  41. self.token_ids = _BLOCK_POOL.alloc_block(block_size)
  42. # this finalizer is used to return the block to the pool when the object is deleted # noqa
  43. # NOTE: don't use __del__ because it cannot guarantee the order of finalization, # noqa
  44. # i.e. `self.token_ids` may be deleted before `self`, and we lose
  45. # the opportunity to return the block to the pool
  46. self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
  47. self.token_ids)
  48. self.num_tokens = 0
  49. def is_empty(self) -> bool:
  50. return self.num_tokens == 0
  51. def get_num_empty_slots(self) -> int:
  52. return self.block_size - self.num_tokens
  53. def is_full(self) -> bool:
  54. return self.num_tokens == self.block_size
  55. def append_tokens(self, token_ids: List[int]) -> None:
  56. assert len(token_ids) <= self.get_num_empty_slots()
  57. curr_idx = self.num_tokens
  58. self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
  59. self.num_tokens += len(token_ids)
  60. def get_token_ids(self) -> List[int]:
  61. return self.token_ids[:self.num_tokens]
  62. def get_last_token_id(self) -> int:
  63. assert self.num_tokens > 0
  64. return self.token_ids[self.num_tokens - 1]
  65. class PhysicalTokenBlock:
  66. """Represents the state of a block in the KV cache."""
  67. def __init__(
  68. self,
  69. device: Device,
  70. block_number: int,
  71. block_size: int,
  72. block_hash: int,
  73. num_hashed_tokens: int,
  74. ) -> None:
  75. self.device = device
  76. self.block_number = block_number
  77. self.block_size = block_size
  78. self.block_hash = block_hash
  79. self.num_hashed_tokens = num_hashed_tokens
  80. self.ref_count = 0
  81. self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
  82. self.computed = False
  83. def __repr__(self) -> str:
  84. return (f'PhysicalTokenBlock(device={self.device}, '
  85. f'block_number={self.block_number}, '
  86. f'num_hashed_tokens={self.num_hashed_tokens}, '
  87. f'ref_count={self.ref_count}, '
  88. f'last_accessed={self.last_accessed}, '
  89. f'computed={self.computed})')
# Mapping: logical block number -> physical block.
# Stored as a list of physical blocks; presumably the list index is the
# logical block number -- confirm against callers.
BlockTable = List[PhysicalTokenBlock]