1
0

block_table.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. import math
  2. from typing import List, Optional
  3. from aphrodite.common.utils import Device, cdiv, chunk_list
  4. from aphrodite.processing.block.common import BlockList
  5. from aphrodite.processing.block.interfaces import (Block,
  6. DeviceAwareBlockAllocator)
  class BlockTable:
      """A class to manage blocks for a specific sequence.

      The BlockTable maps a sequence of tokens to a list of blocks, where each
      block represents a contiguous memory allocation for a portion of the
      sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
      responsible for allocating and freeing memory for the blocks.

      Args:
          block_size (int): The maximum number of tokens that can be stored in
              a single block.
          block_allocator (DeviceAwareBlockAllocator): The block allocator used
              to manage memory for the blocks.
          _blocks (Optional[List[Block]], optional): An optional list of
              existing blocks to initialize the BlockTable with. If not
              provided, an empty BlockTable is created.
          max_block_sliding_window (Optional[int], optional): The number of
              blocks to keep around for each sequence. If None, all blocks
              are kept (e.g., when sliding window is not used).
              It should at least fit the sliding window size of the model.

      Attributes:
          _block_size (int): The maximum number of tokens that can be stored in
              a single block.
          _allocator (DeviceAwareBlockAllocator): The block allocator used to
              manage memory for the blocks.
          _blocks (BlockList): The blocks managed by this BlockTable.
          _max_block_sliding_window (Optional[int]): Number of most-recent
              blocks retained when a sliding window is in use (older ones are
              swapped for the allocator's null block in `append_token_ids`);
              None disables dropping.
          _num_full_slots (int): The number of tokens currently stored in the
              blocks.
      """

      def __init__(
          self,
          block_size: int,
          block_allocator: DeviceAwareBlockAllocator,
          _blocks: Optional[List[Block]] = None,
          max_block_sliding_window: Optional[int] = None,
      ) -> None:
          self._block_size = block_size
          self._allocator = block_allocator
          if _blocks is None:
              _blocks = []
          self._blocks: BlockList = BlockList(_blocks)

          self._max_block_sliding_window = max_block_sliding_window
          # Derive the filled-slot count from the blocks' contents so a table
          # built from pre-existing blocks (e.g. via `fork`) starts consistent.
          self._num_full_slots = self._get_num_token_ids()

      @staticmethod
      def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
          """Calculates the minimum number of blocks required to store a given
          sequence of token IDs.

          This assumes worst-case scenario, where every block requires a new
          allocation (e.g. ignoring prefix caching).

          Args:
              token_ids (List[int]): The sequence of token IDs to be stored.
              block_size (int): The maximum number of tokens that can be stored
                  in a single block.

          Returns:
              int: The minimum number of blocks required to store the given
              sequence of token IDs.
          """
          return cdiv(len(token_ids), block_size)

      def allocate(self,
                   token_ids: List[int],
                   device: Device = Device.GPU) -> None:
          """Allocates memory blocks for storing the given sequence of token IDs.

          The table must not already hold blocks. Every full block-sized chunk
          of `token_ids` is stored in an immutable block; a trailing partial
          chunk, if any, is stored in a single mutable block.

          Args:
              token_ids (List[int]): The sequence of token IDs to be stored.
                  Must be non-empty.
              device (Device, optional): The device on which the blocks should
                  be allocated. Defaults to Device.GPU.
          """
          assert not self._is_allocated
          assert token_ids
          blocks = self._allocate_blocks_for_token_ids(prev_block=None,
                                                       token_ids=token_ids,
                                                       device=device)
          self.update(blocks)
          self._num_full_slots = len(token_ids)

      def update(self, blocks: List[Block]) -> None:
          """Resets the table to the newly provided blocks
          (with their corresponding block ids).
          """
          self._blocks.update(blocks)

      def append_token_ids(self,
                           token_ids: List[int],
                           num_lookahead_slots: int = 0,
                           num_computed_slots: Optional[int] = None) -> None:
          """Appends a sequence of token IDs to the existing blocks in the
          BlockTable.

          This method appends the given sequence of token IDs to the existing
          blocks in the BlockTable. If there is not enough space in the existing
          blocks, new blocks are allocated using the `ensure_num_empty_slots`
          method to accommodate the additional tokens.

          The token IDs are divided into chunks of size `block_size` (except for
          the first chunk, which may be smaller), and each chunk is appended to a
          separate block.

          Args:
              token_ids (List[int]): The sequence of token IDs to be appended.
                  May be empty.
              num_lookahead_slots (int): Extra empty slots to reserve beyond
                  the appended tokens. They are allocated but not counted as
                  full slots.
              num_computed_slots (Optional[int]): The number of KV cache slots
                  that are already filled (computed).
                  When sliding window is enabled, this is used to compute how
                  many blocks to drop at the front of the sequence.
                  Without sliding window, None can be passed.
                  Without chunked prefill, it should be the same as
                  _num_full_slots.
          """
          assert self._is_allocated, "no blocks have been allocated"
          assert len(self._blocks) > 0

          # Drop blocks that are no longer needed due to sliding window
          if self._max_block_sliding_window is not None:
              null_block = self._allocator.allocate_or_get_null_block()
              assert num_computed_slots is not None
              # Blocks before this index lie entirely outside the window.
              # A negative value yields an empty range, i.e. nothing dropped.
              end_block_idx = (num_computed_slots // self._block_size
                               - self._max_block_sliding_window)
              for idx in range(0, end_block_idx):
                  b = self._blocks[idx]
                  # Slots already holding the null block were freed earlier.
                  if b is not null_block:
                      self._allocator.free(b)
                      self._blocks[idx] = null_block

          # Ensure there are enough empty slots for the new tokens plus
          # lookahead slots
          self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                      num_lookahead_slots)

          # Update the blocks with the new tokens. The first chunk lands in the
          # block holding slot index `_num_full_slots`.
          first_block_idx = self._num_full_slots // self._block_size
          token_blocks = self._chunk_token_blocks_for_append(token_ids)

          for i, token_block in enumerate(token_blocks):
              self._blocks.append_token_ids(first_block_idx + i, token_block)

          self._num_full_slots += len(token_ids)

      def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
          """Ensures that the BlockTable has at least the specified number of
          empty slots available.

          This method checks if the BlockTable has enough empty slots (i.e.,
          available space) to accommodate the requested number of tokens. If
          not, it allocates additional mutable blocks on the GPU, each chained
          after the current last block, until the requirement is met.

          Args:
              num_empty_slots (int): The minimum number of empty slots required.
          """
          # Currently the block table only supports
          # appending tokens to GPU blocks.
          device = Device.GPU
          assert self._is_allocated

          if self._num_empty_slots >= num_empty_slots:
              return

          slots_to_allocate = num_empty_slots - self._num_empty_slots
          blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)

          # The table is non-empty (asserted above) and only grows here, so
          # `self._blocks[-1]` is always valid.
          for _ in range(blocks_to_allocate):
              self._blocks.append(
                  self._allocator.allocate_mutable_block(
                      prev_block=self._blocks[-1], device=device))

      def fork(self) -> "BlockTable":
          """Creates a new BlockTable instance with a copy of the blocks from the
          current instance.

          The new table has the same block size, block allocator and sliding
          window setting. Its blocks are obtained from `self._allocator.fork`
          applied to this table's last block.

          Returns:
              BlockTable: A new BlockTable instance with a copy of the blocks
              from the current instance.
          """
          assert self._is_allocated
          assert len(self._blocks) > 0
          forked_blocks = self._allocator.fork(self._blocks[-1])
          return BlockTable(
              block_size=self._block_size,
              block_allocator=self._allocator,
              _blocks=forked_blocks,
              max_block_sliding_window=self._max_block_sliding_window,
          )

      def free(self) -> None:
          """Frees the memory occupied by the blocks in the BlockTable.

          Each block is passed to the allocator's `free` method. The internal
          block list is then reset (emptied), so the table is no longer
          considered allocated.
          """
          assert self._is_allocated
          for block in self.blocks:
              self._allocator.free(block)
          self._blocks.reset()

      @property
      def physical_block_ids(self) -> List[int]:
          """Returns a list of physical block indices for the blocks in the
          BlockTable.

          Each integer is the physical block index of the corresponding block
          in `_blocks`, i.e. an identifier of the memory location that block
          occupies.

          Returns:
              List[int]: A list of physical block indices for the blocks in the
              BlockTable.
          """
          assert self._is_allocated
          return self._blocks.ids()

      def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
          """Get the "unseen" tokens in the sequence.

          Unseen tokens are tokens in the sequence corresponding to this block
          table, but are not yet appended to this block table.

          Args:
              sequence_token_ids (List[int]): The list of token ids in the
                  sequence.

          Returns:
              List[int]: The postfix of sequence_token_ids that has not yet been
              appended to the block table.
          """
          # Since the block table is append-only, the unseen token ids are the
          # ones after the appended ones.
          return sequence_token_ids[self.num_full_slots:]

      def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
                                         token_ids: List[int],
                                         device: Device) -> List[Block]:
          """Allocate blocks holding `token_ids`, chained after `prev_block`.

          All full `block_size` chunks are allocated in one batch as immutable
          blocks. A trailing partial chunk, if present, is written into a newly
          allocated mutable block chained after them.

          Returns:
              List[Block]: The newly allocated blocks in sequence order.
          """
          blocks: List[Block] = []

          block_token_ids = []
          tail_token_ids = []
          for cur_token_ids in chunk_list(token_ids, self._block_size):
              if len(cur_token_ids) == self._block_size:
                  block_token_ids.append(cur_token_ids)
              else:
                  tail_token_ids.append(cur_token_ids)

          if block_token_ids:
              blocks.extend(
                  self._allocator.allocate_immutable_blocks(
                      prev_block, block_token_ids=block_token_ids,
                      device=device))
              prev_block = blocks[-1]

          if tail_token_ids:
              # Only the final chunk can be shorter than a block.
              assert len(tail_token_ids) == 1
              cur_token_ids = tail_token_ids[0]

              block = self._allocator.allocate_mutable_block(
                  prev_block=prev_block, device=device)
              block.append_token_ids(cur_token_ids)

              blocks.append(block)

          return blocks

      def _get_all_token_ids(self) -> List[int]:
          # NOTE: This function is O(seq_len); use sparingly.
          token_ids: List[int] = []

          if not self._is_allocated:
              return token_ids

          for block in self.blocks:
              token_ids.extend(block.token_ids)

          return token_ids

      def _get_num_token_ids(self) -> int:
          """Total number of token ids stored across all blocks."""
          return sum(len(block.token_ids) for block in self.blocks)

      @property
      def _is_allocated(self) -> bool:
          return len(self._blocks) > 0

      @property
      def blocks(self) -> List[Block]:
          return self._blocks.list()

      @property
      def _num_empty_slots(self) -> int:
          assert self._is_allocated
          return len(self._blocks) * self._block_size - self._num_full_slots

      @property
      def num_full_slots(self) -> int:
          """Returns the total number of tokens currently stored in the
          BlockTable.

          Returns:
              int: The total number of tokens currently stored in the
              BlockTable.
          """
          return self._num_full_slots

      def get_num_blocks_touched_by_append_slots(
              self, token_ids: List[int], num_lookahead_slots: int) -> int:
          """Determine how many blocks will be "touched" by appending the token
          ids.

          This is required for the scheduler to determine whether a sequence can
          continue generation, or if it must be preempted.

          Args:
              token_ids (List[int]): The token IDs that would be appended.
              num_lookahead_slots (int): Additional slots that would be
                  reserved after the token IDs.

          Returns:
              int: The number of blocks the append would write into, counting
              the partially-filled last block if the first chunk lands there.
          """
          # Math below counts the chunks that
          #   self._chunk_token_blocks_for_append(
          #       token_ids + [-1] * num_lookahead_slots)
          # would produce, without materializing that list. One exception: an
          # append of zero slots onto a partially-full last block yields 1 here
          # (that block) while the chunker returns no chunks.
          # NOTE(review): this appears to be a deliberate conservative
          # over-count for the scheduler; confirm before changing.
          num_token_ids = len(token_ids) + num_lookahead_slots
          first_chunk_size = self._block_size - (self._num_full_slots %
                                                 self._block_size)
          num_token_blocks = (1 + math.ceil(
              (num_token_ids - first_chunk_size) / self._block_size))
          return num_token_blocks

      def _chunk_token_blocks_for_append(
              self, token_ids: List[int]) -> List[List[int]]:
          """Split the token ids into block-sized chunks so they can be easily
          appended to blocks. The first such "token block" may have fewer token
          ids than the block size, since the last allocated block may be
          partially full.

          If no token ids are provided, no chunks are returned.
          """
          # Without this guard an empty input would yield [[]], and
          # append_token_ids would then address block index
          # `_num_full_slots // _block_size`. That index is one past the last
          # block whenever the table is exactly full and no lookahead slots
          # were requested.
          if not token_ids:
              return []

          first_chunk_size = self._block_size - (self._num_full_slots %
                                                 self._block_size)
          token_blocks = [token_ids[:first_chunk_size]]
          token_blocks.extend(
              chunk_list(token_ids[first_chunk_size:], self._block_size))
          return token_blocks