# block_table.py (11 KB)

from typing import List, Optional

from aphrodite.common.utils import Device, cdiv, chunk_list
from aphrodite.processing.block.interfaces import (Block,
                                                   DeviceAwareBlockAllocator)


  5. class BlockTable:
  6. """A class to manage blocks for a specific sequence.
  7. The BlockTable maps a sequence of tokens to a list of blocks, where each
  8. block represents a contiguous memory allocation for a portion of the
  9. sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
  10. responsible for allocating and freeing memory for the blocks.
  11. Args:
  12. block_size (int): The maximum number of tokens that can be stored in a
  13. single block.
  14. block_allocator (DeviceAwareBlockAllocator): The block allocator used to
  15. manage memory for the blocks.
  16. _blocks (Optional[List[Block]], optional): An optional list of existing
  17. blocks to initialize the BlockTable with. If not provided, an empty
  18. BlockTable is created.
  19. Attributes:
  20. _block_size (int): The maximum number of tokens that can be stored in a
  21. single block.
  22. _allocator (DeviceAwareBlockAllocator): The block allocator used to
  23. manage memory for the blocks.
  24. _blocks (Optional[List[Block]]): The list of blocks managed by this
  25. BlockTable.
  26. _num_full_slots (int): The number of tokens currently stored in the
  27. blocks.
  28. """
  29. def __init__(
  30. self,
  31. block_size: int,
  32. block_allocator: DeviceAwareBlockAllocator,
  33. _blocks: Optional[List[Block]] = None,
  34. ):
  35. self._block_size = block_size
  36. self._allocator = block_allocator
  37. self._blocks: Optional[List[Block]] = _blocks
  38. # Use helper method instead of directly calculating, as blocks
  39. # may not be allocated.
  40. self._num_full_slots = len(self._get_all_token_ids())
  41. @staticmethod
  42. def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
  43. """Calculates the minimum number of blocks required to store a given
  44. sequence of token IDs.
  45. This assumes worst-case scenario, where every block requires a new
  46. allocation (e.g. ignoring prefix caching).
  47. Args:
  48. token_ids (List[int]): The sequence of token IDs to be stored.
  49. block_size (int): The maximum number of tokens that can be stored in
  50. a single block.
  51. Returns:
  52. int: The minimum number of blocks required to store the given
  53. sequence of token IDs.
  54. """
  55. return cdiv(len(token_ids), block_size)
  56. def allocate(self,
  57. token_ids: List[int],
  58. device: Device = Device.GPU) -> None:
  59. """Allocates memory blocks for storing the given sequence of token IDs.
  60. This method allocates the required number of blocks to store the given
  61. sequence of token IDs.
  62. Args:
  63. token_ids (List[int]): The sequence of token IDs to be stored.
  64. device (Device, optional): The device on which the blocks should be
  65. allocated. Defaults to Device.GPU.
  66. """
  67. assert not self._is_allocated
  68. assert token_ids
  69. self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
  70. token_ids=token_ids,
  71. device=device)
  72. self._num_full_slots = len(token_ids)
  73. def append_token_ids(self,
  74. token_ids: List[int],
  75. num_lookahead_slots: int = 0) -> None:
  76. """Appends a sequence of token IDs to the existing blocks in the
  77. BlockTable.
  78. This method appends the given sequence of token IDs to the existing
  79. blocks in the BlockTable. If there is not enough space in the existing
  80. blocks, new blocks are allocated using the `ensure_num_empty_slots`
  81. method to accommodate the additional tokens.
  82. The token IDs are divided into chunks of size `block_size` (except for
  83. the first chunk, which may be smaller), and each chunk is appended to a
  84. separate block.
  85. Args:
  86. token_ids (List[int]): The sequence of token IDs to be appended.
  87. """
  88. assert self._is_allocated
  89. assert self._blocks is not None
  90. self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
  91. num_lookahead_slots)
  92. blocks = self._blocks[self._num_full_slots // self._block_size:]
  93. token_blocks = self._chunk_token_blocks_for_append(token_ids)
  94. for block, token_block in zip(blocks, token_blocks):
  95. block.append_token_ids(token_block)
  96. self._num_full_slots += len(token_ids)
  97. def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
  98. """Ensures that the BlockTable has at least the specified number of
  99. empty slots available.
  100. This method checks if the BlockTable has enough empty slots (i.e.,
  101. available space) to accommodate the requested number of tokens. If not,
  102. it allocates additional blocks on the GPU to ensure that the required
  103. number of empty slots is available.
  104. Args:
  105. num_empty_slots (int): The minimum number of empty slots required.
  106. """
  107. # Currently the block table only supports
  108. # appending tokens to GPU blocks.
  109. device = Device.GPU
  110. assert self._is_allocated
  111. if self._num_empty_slots >= num_empty_slots:
  112. return
  113. slots_to_allocate = num_empty_slots - self._num_empty_slots
  114. blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
  115. for _ in range(blocks_to_allocate):
  116. self._blocks.append(
  117. self._allocator.allocate_mutable(prev_block=self._blocks[-1],
  118. device=device))
  119. def fork(self) -> "BlockTable":
  120. """Creates a new BlockTable instance with a copy of the blocks from the
  121. current instance.
  122. This method creates a new BlockTable instance with the same block size,
  123. block allocator, and a copy of the blocks from the current instance. The
  124. new BlockTable has its own independent set of blocks, but shares the
  125. same underlying memory allocation with the original BlockTable.
  126. Returns:
  127. BlockTable: A new BlockTable instance with a copy of the blocks from
  128. the current instance.
  129. """
  130. assert self._is_allocated
  131. forked_blocks = self._allocator.fork(self._blocks[-1])
  132. return BlockTable(
  133. block_size=self._block_size,
  134. block_allocator=self._allocator,
  135. _blocks=forked_blocks,
  136. )
  137. def free(self) -> None:
  138. """Frees the memory occupied by the blocks in the BlockTable.
  139. This method iterates over all the blocks in the `_blocks` list and calls
  140. the `free` method of the `_allocator` object to release the memory
  141. occupied by each block. After freeing all the blocks, the `_blocks` list
  142. is set to `None`.
  143. """
  144. assert self._is_allocated
  145. for block in self._blocks:
  146. self._allocator.free(block)
  147. self._blocks = None
  148. @property
  149. def physical_block_ids(self) -> List[int]:
  150. """Returns a list of physical block indices for the blocks in the
  151. BlockTable.
  152. This property returns a list of integers, where each integer represents
  153. the physical block index of a corresponding block in the `_blocks` list.
  154. The physical block index is a unique identifier for the memory location
  155. occupied by the block.
  156. Returns:
  157. List[int]: A list of physical block indices for the blocks in the
  158. BlockTable.
  159. """
  160. assert self._is_allocated
  161. return [block.block_id for block in self._blocks]
  162. def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
  163. """Get the number of "unseen" tokens in the sequence.
  164. Unseen tokens are tokens in the sequence corresponding to this block
  165. table, but are not yet appended to this block table.
  166. Args:
  167. sequence_token_ids (List[int]): The list of token ids in the
  168. sequence.
  169. Returns:
  170. List[int]: The postfix of sequence_token_ids that has not yet been
  171. appended to the block table.
  172. """
  173. # Since the block table is append-only, the unseen token ids are the
  174. # ones after the appended ones.
  175. return sequence_token_ids[self.num_full_slots:]
  176. def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
  177. token_ids: List[int],
  178. device: Device) -> List[Block]:
  179. blocks = []
  180. for block_token_ids in chunk_list(token_ids, self._block_size):
  181. if len(block_token_ids) == self._block_size:
  182. # If the block is full, create an immutable block.
  183. prev_block = self._allocator.allocate_immutable(
  184. prev_block, token_ids=block_token_ids, device=device)
  185. else:
  186. # Else, partially fill a mutable block with token ids.
  187. prev_block = self._allocator.allocate_mutable(
  188. prev_block=prev_block, device=device)
  189. prev_block.append_token_ids(block_token_ids)
  190. blocks.append(prev_block)
  191. return blocks
  192. def _get_all_token_ids(self) -> List[int]:
  193. # NOTE: This function is O(seq_len); use sparingly.
  194. token_ids = []
  195. if not self._is_allocated:
  196. return token_ids
  197. for block in self._blocks:
  198. token_ids.extend(block.token_ids)
  199. return token_ids
  200. @property
  201. def _is_allocated(self) -> bool:
  202. return self._blocks is not None
  203. @property
  204. def _num_empty_slots(self) -> int:
  205. assert self._is_allocated
  206. return len(self._blocks) * self._block_size - self._num_full_slots
  207. @property
  208. def num_full_slots(self) -> int:
  209. """Returns the total number of tokens currently stored in the
  210. BlockTable.
  211. Returns:
  212. int: The total number of tokens currently stored in the BlockTable.
  213. """
  214. return self._num_full_slots
  215. def get_num_blocks_touched_by_append_slots(
  216. self, token_ids: List[int], num_lookahead_slots: int) -> int:
  217. """Determine how many blocks will be "touched" by appending the token
  218. ids.
  219. This is required for the scheduler to determine whether a sequence can
  220. continue generation, or if it must be preempted.
  221. """
  222. all_token_ids = token_ids + [-1] * num_lookahead_slots
  223. token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
  224. return len(token_blocks)
  225. def _chunk_token_blocks_for_append(
  226. self, token_ids: List[int]) -> List[List[int]]:
  227. """Split the token ids into block-sized chunks so they can be easily
  228. appended to blocks. The first such "token block" may have less token ids
  229. than the block size, since the last allocated block may be partially
  230. full.
  231. """
  232. first_chunk_size = self._block_size - (self._num_full_slots %
  233. self._block_size)
  234. token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
  235. token_ids[first_chunk_size:], self._block_size)
  236. return token_blocks