# block_table.py
  1. from typing import List, Optional
  2. from aphrodite.processing.block.interfaces import (
  3. Block,
  4. DeviceAwareBlockAllocator,
  5. )
  6. from aphrodite.common.utils import Device, cdiv, chunk_list
  7. class BlockTable:
  8. """A class to manage blocks for a specific sequence.
  9. The BlockTable maps a sequence of tokens to a list of blocks, where each
  10. block represents a contiguous memory allocation for a portion of the
  11. sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
  12. responsible for allocating and freeing memory for the blocks.
  13. Args:
  14. block_size (int): The maximum number of tokens that can be stored in a
  15. single block.
  16. block_allocator (DeviceAwareBlockAllocator): The block allocator used to
  17. manage memory for the blocks.
  18. _blocks (Optional[List[Block]], optional): An optional list of existing
  19. blocks to initialize the BlockTable with. If not provided, an empty
  20. BlockTable is created.
  21. Attributes:
  22. _block_size (int): The maximum number of tokens that can be stored in a
  23. single block.
  24. _allocator (DeviceAwareBlockAllocator): The block allocator used to
  25. manage memory for the blocks.
  26. _blocks (Optional[List[Block]]): The list of blocks managed by this
  27. BlockTable.
  28. _num_full_slots (int): The number of tokens currently stored in the
  29. blocks.
  30. """
  31. def __init__(
  32. self,
  33. block_size: int,
  34. block_allocator: DeviceAwareBlockAllocator,
  35. _blocks: Optional[List[Block]] = None,
  36. ):
  37. self._block_size = block_size
  38. self._allocator = block_allocator
  39. self._blocks: Optional[List[Block]] = _blocks
  40. # Use helper method instead of directly calculating, as blocks
  41. # may not be allocated.
  42. self._num_full_slots = len(self._get_all_token_ids())
  43. @staticmethod
  44. def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
  45. """Calculates the minimum number of blocks required to store a given
  46. sequence of token IDs.
  47. This assumes worst-case scenario, where every block requires a new
  48. allocation (e.g. ignoring prefix caching).
  49. Args:
  50. token_ids (List[int]): The sequence of token IDs to be stored.
  51. block_size (int): The maximum number of tokens that can be stored in
  52. a single block.
  53. Returns:
  54. int: The minimum number of blocks required to store the given
  55. sequence of token IDs.
  56. """
  57. return cdiv(len(token_ids), block_size)
  58. def allocate(self,
  59. token_ids: List[int],
  60. device: Device = Device.GPU) -> None:
  61. """Allocates memory blocks for storing the given sequence of token IDs.
  62. This method allocates the required number of blocks to store the given
  63. sequence of token IDs.
  64. Args:
  65. token_ids (List[int]): The sequence of token IDs to be stored.
  66. device (Device, optional): The device on which the blocks should be
  67. allocated. Defaults to Device.GPU.
  68. """
  69. assert not self._is_allocated
  70. assert token_ids
  71. self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
  72. token_ids=token_ids,
  73. device=device)
  74. self._num_full_slots = len(token_ids)
  75. def append_token_ids(self,
  76. token_ids: List[int],
  77. num_lookahead_slots: int = 0) -> None:
  78. """Appends a sequence of token IDs to the existing blocks in the
  79. BlockTable.
  80. This method appends the given sequence of token IDs to the existing
  81. blocks in the BlockTable. If there is not enough space in the existing
  82. blocks, new blocks are allocated using the `ensure_num_empty_slots`
  83. method to accommodate the additional tokens.
  84. The token IDs are divided into chunks of size `block_size` (except for
  85. the first chunk, which may be smaller), and each chunk is appended to a
  86. separate block.
  87. Args:
  88. token_ids (List[int]): The sequence of token IDs to be appended.
  89. """
  90. assert self._is_allocated
  91. assert token_ids, "can't append empty token ids"
  92. self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
  93. num_lookahead_slots)
  94. blocks = self._blocks[self._num_full_slots // self._block_size:]
  95. token_blocks = self._chunk_token_blocks_for_append(token_ids)
  96. for block, token_block in zip(blocks, token_blocks):
  97. block.append_token_ids(token_block)
  98. self._num_full_slots += len(token_ids)
  99. def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
  100. """Ensures that the BlockTable has at least the specified number of
  101. empty slots available.
  102. This method checks if the BlockTable has enough empty slots (i.e.,
  103. available space) to accommodate the requested number of tokens. If not,
  104. it allocates additional blocks on the GPU to ensure that the required
  105. number of empty slots is available.
  106. Args:
  107. num_empty_slots (int): The minimum number of empty slots required.
  108. """
  109. # Currently the block table only supports
  110. # appending tokens to GPU blocks.
  111. device = Device.GPU
  112. assert self._is_allocated
  113. if self._num_empty_slots >= num_empty_slots:
  114. return
  115. slots_to_allocate = num_empty_slots - self._num_empty_slots
  116. blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
  117. for _ in range(blocks_to_allocate):
  118. self._blocks.append(
  119. self._allocator.allocate_mutable(prev_block=self._blocks[-1],
  120. device=device))
  121. def fork(self) -> "BlockTable":
  122. """Creates a new BlockTable instance with a copy of the blocks from the
  123. current instance.
  124. This method creates a new BlockTable instance with the same block size,
  125. block allocator, and a copy of the blocks from the current instance. The
  126. new BlockTable has its own independent set of blocks, but shares the
  127. same underlying memory allocation with the original BlockTable.
  128. Returns:
  129. BlockTable: A new BlockTable instance with a copy of the blocks from
  130. the current instance.
  131. """
  132. assert self._is_allocated
  133. forked_blocks = self._allocator.fork(self._blocks[-1])
  134. return BlockTable(
  135. block_size=self._block_size,
  136. block_allocator=self._allocator,
  137. _blocks=forked_blocks,
  138. )
  139. def free(self) -> None:
  140. """Frees the memory occupied by the blocks in the BlockTable.
  141. This method iterates over all the blocks in the `_blocks` list and calls
  142. the `free` method of the `_allocator` object to release the memory
  143. occupied by each block. After freeing all the blocks, the `_blocks` list
  144. is set to `None`.
  145. """
  146. assert self._is_allocated
  147. for block in self._blocks:
  148. self._allocator.free(block)
  149. self._blocks = None
  150. @property
  151. def physical_block_ids(self) -> List[int]:
  152. """Returns a list of physical block indices for the blocks in the
  153. BlockTable.
  154. This property returns a list of integers, where each integer represents
  155. the physical block index of a corresponding block in the `_blocks` list.
  156. The physical block index is a unique identifier for the memory location
  157. occupied by the block.
  158. Returns:
  159. List[int]: A list of physical block indices for the blocks in the
  160. BlockTable.
  161. """
  162. assert self._is_allocated
  163. return [block.block_id for block in self._blocks]
  164. def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
  165. """Get the number of "unseen" tokens in the sequence.
  166. Unseen tokens are tokens in the sequence corresponding to this block
  167. table, but are not yet appended to this block table.
  168. Args:
  169. sequence_token_ids (List[int]): The list of token ids in the
  170. sequence.
  171. Returns:
  172. List[int]: The postfix of sequence_token_ids that has not yet been
  173. appended to the block table.
  174. """
  175. # Since the block table is append-only, the unseen token ids are the
  176. # ones after the appended ones.
  177. return sequence_token_ids[self.num_full_slots:]
  178. def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
  179. token_ids: List[int],
  180. device: Device) -> List[Block]:
  181. blocks = []
  182. for block_token_ids in chunk_list(token_ids, self._block_size):
  183. if len(block_token_ids) == self._block_size:
  184. # If the block is full, create an immutable block.
  185. prev_block = self._allocator.allocate_immutable(
  186. prev_block, token_ids=block_token_ids, device=device)
  187. else:
  188. # Else, partially fill a mutable block with token ids.
  189. prev_block = self._allocator.allocate_mutable(
  190. prev_block=prev_block, device=device)
  191. prev_block.append_token_ids(block_token_ids)
  192. blocks.append(prev_block)
  193. return blocks
  194. def _get_all_token_ids(self) -> List[int]:
  195. # NOTE: This function is O(seq_len); use sparingly.
  196. token_ids = []
  197. if not self._is_allocated:
  198. return token_ids
  199. for block in self._blocks:
  200. token_ids.extend(block.token_ids)
  201. return token_ids
  202. @property
  203. def _is_allocated(self) -> bool:
  204. return self._blocks is not None
  205. @property
  206. def _num_empty_slots(self) -> int:
  207. assert self._is_allocated
  208. return len(self._blocks) * self._block_size - self._num_full_slots
  209. @property
  210. def num_full_slots(self) -> int:
  211. """Returns the total number of tokens currently stored in the
  212. BlockTable.
  213. Returns:
  214. int: The total number of tokens currently stored in the BlockTable.
  215. """
  216. return self._num_full_slots
  217. def get_num_blocks_touched_by_append_slots(
  218. self, token_ids: List[int], num_lookahead_slots: int) -> int:
  219. """Determine how many blocks will be "touched" by appending the token
  220. ids.
  221. This is required for the scheduler to determine whether a sequence can
  222. continue generation, or if it must be preempted.
  223. """
  224. all_token_ids = token_ids + [-1] * num_lookahead_slots
  225. token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
  226. return len(token_blocks)
  227. def _chunk_token_blocks_for_append(
  228. self, token_ids: List[int]) -> List[List[int]]:
  229. """Split the token ids into block-sized chunks so they can be easily
  230. appended to blocks. The first such "token block" may have less token ids
  231. than the block size, since the last allocated block may be partially
  232. full.
  233. """
  234. first_chunk_size = self._block_size - (self._num_full_slots %
  235. self._block_size)
  236. token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
  237. token_ids[first_chunk_size:], self._block_size)
  238. return token_blocks