naive_block.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. from collections import deque
  2. from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
  3. from aphrodite.common.utils import cdiv
  4. from aphrodite.processing.block.common import (BlockPool, CopyOnWriteTracker,
  5. RefCounter,
  6. get_all_blocks_recursively)
  7. from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
  8. BlockId, Device)
  9. Refcount = int  # Alias for integer reference counts; not referenced elsewhere in the visible code.
  10. class NaiveBlockAllocator(BlockAllocator):
  11. """A simple block allocator that manages blocks of memory without prefix
  12. caching.
  13. Args:
  14. create_block (Block.Factory): A factory function for creating new
  15. blocks. This is used when a NaiveBlockAllocator is composed within
  16. a prefix caching allocator -- the naive block allocator must
  17. construct prefix caching blocks (but shouldn't know anything else
  18. about them).
  19. num_blocks (int): The total number of blocks to manage.
  20. block_size (int): The size of each block in tokens.
  21. block_ids (Optional[Iterable[int]], optional): An optional iterable of
  22. block IDs. If not provided, block IDs will be assigned sequentially
  23. from 0 to num_blocks - 1.
  24. """
  25. def __init__(
  26. self,
  27. create_block: Block.Factory,
  28. num_blocks: int,
  29. block_size: int,
  30. block_ids: Optional[Iterable[int]] = None,
  31. block_pool: Optional[BlockPool] = None,
  32. ):
  33. if block_ids is None:
  34. block_ids = range(num_blocks)
  35. self._free_block_indices: Deque[BlockId] = deque(block_ids)
  36. self._all_block_indices = frozenset(block_ids)
  37. assert len(self._all_block_indices) == num_blocks
  38. self._refcounter = RefCounter(
  39. all_block_indices=self._free_block_indices)
  40. self._block_size = block_size
  41. self._cow_tracker = CopyOnWriteTracker(
  42. refcounter=self._refcounter.as_readonly())
  43. if block_pool is None:
  44. extra_factor = 4
  45. # Pre-allocate "num_blocks * extra_factor" block objects.
  46. # The "* extra_factor" is a buffer to allow more block objects
  47. # than physical blocks
  48. self._block_pool = BlockPool(self._block_size, create_block, self,
  49. num_blocks * extra_factor)
  50. else:
  51. # In this case, the block pool is provided by the caller,
  52. # which means that there is most likely a need to share
  53. # a block pool between allocators
  54. self._block_pool = block_pool
  55. def allocate_immutable_block(self,
  56. prev_block: Optional[Block],
  57. token_ids: List[int],
  58. device: Optional[Device] = None) -> Block:
  59. """Allocates a new immutable block with the given token IDs, linked to
  60. the previous block.
  61. Args:
  62. prev_block (Optional[Block]): The previous block in the sequence. If
  63. None, then the block to be allocated is the first block in the
  64. sequence.
  65. token_ids (List[int]): The token IDs to be stored in the new block.
  66. Returns:
  67. Block: The newly allocated immutable block.
  68. """
  69. assert device is None
  70. block = self.allocate_mutable_block(prev_block=prev_block)
  71. block.append_token_ids(token_ids)
  72. return block
  73. def allocate_immutable_blocks(
  74. self,
  75. prev_block: Optional[Block],
  76. block_token_ids: List[List[int]],
  77. device: Optional[Device] = None) -> List[Block]:
  78. assert device is None
  79. num_blocks = len(block_token_ids)
  80. block_ids = []
  81. for i in range(num_blocks):
  82. block_ids.append(self._allocate_block_id())
  83. blocks = []
  84. for i in range(num_blocks):
  85. prev_block = self._block_pool.init_block(
  86. prev_block=prev_block,
  87. token_ids=block_token_ids[i],
  88. block_size=self._block_size,
  89. physical_block_id=block_ids[i])
  90. blocks.append(prev_block)
  91. return blocks
  92. def allocate_mutable_block(self,
  93. prev_block: Optional[Block],
  94. device: Optional[Device] = None) -> Block:
  95. """Allocates a new mutable block, linked to the previous block.
  96. Args:
  97. prev_block (Optional[Block]): The previous block in the sequence. If
  98. None, then the block to be allocated is the first block in the
  99. sequence.
  100. Returns:
  101. Block: The newly allocated mutable block.
  102. """
  103. assert device is None
  104. block_id = self._allocate_block_id()
  105. block = self._block_pool.init_block(prev_block=prev_block,
  106. token_ids=[],
  107. block_size=self._block_size,
  108. physical_block_id=block_id)
  109. return block
  110. def _allocate_block_id(self) -> BlockId:
  111. if not self._free_block_indices:
  112. raise BlockAllocator.NoFreeBlocksError()
  113. block_id = self._free_block_indices.popleft()
  114. self._refcounter.incr(block_id)
  115. return block_id
  116. def _free_block_id(self, block: Block) -> None:
  117. block_id = block.block_id
  118. assert block_id is not None
  119. refcount = self._refcounter.decr(block_id)
  120. if refcount == 0:
  121. self._free_block_indices.appendleft(block_id)
  122. block.block_id = None
  123. def free(self, block: Block, keep_block_object: bool = False) -> None:
  124. # Release the physical block id
  125. self._free_block_id(block)
  126. # Release the block object
  127. if not keep_block_object:
  128. self._block_pool.free_block(block)
  129. def fork(self, last_block: Block) -> List[Block]:
  130. """Creates a new sequence of blocks that shares the same underlying
  131. memory as the original sequence.
  132. Args:
  133. last_block (Block): The last block in the original sequence.
  134. Returns:
  135. List[Block]: The new sequence of blocks that shares the same memory
  136. as the original sequence.
  137. """
  138. source_blocks = get_all_blocks_recursively(last_block)
  139. forked_blocks: List[Block] = []
  140. prev_block = None
  141. for block in source_blocks:
  142. # Increment refcount for each block.
  143. assert block.block_id is not None
  144. refcount = self._refcounter.incr(block.block_id)
  145. assert refcount != 1, "can't fork free'd block"
  146. forked_block = self._block_pool.init_block(
  147. prev_block=prev_block,
  148. token_ids=block.token_ids,
  149. block_size=self._block_size,
  150. physical_block_id=block.block_id)
  151. forked_blocks.append(forked_block)
  152. prev_block = forked_blocks[-1]
  153. return forked_blocks
  154. def get_num_free_blocks(self) -> int:
  155. return len(self._free_block_indices)
  156. def get_num_total_blocks(self) -> int:
  157. return len(self._all_block_indices)
  158. def get_physical_block_id(self, absolute_id: int) -> int:
  159. """Returns the zero-offset block id on certain block allocator
  160. given the absolute block id.
  161. Args:
  162. absolute_id (int): The absolute block id for the block
  163. in whole allocator.
  164. Returns:
  165. int: The zero-offset block id on certain device.
  166. """
  167. return sorted(self._all_block_indices).index(absolute_id)
  168. @property
  169. def refcounter(self):
  170. return self._refcounter
  171. @property
  172. def all_block_ids(self) -> FrozenSet[int]:
  173. return self._all_block_indices
  174. def cow_block_if_not_appendable(self, block: Block) -> BlockId:
  175. """Performs a copy-on-write operation on the given block if it is not
  176. appendable.
  177. Args:
  178. block (Block): The block to check for copy-on-write.
  179. Returns:
  180. BlockId: The block index of the new block if a copy-on-write
  181. operation was performed, or the original block index if
  182. no copy-on-write was necessary.
  183. """
  184. src_block_id = block.block_id
  185. assert src_block_id is not None
  186. if self._cow_tracker.is_appendable(block):
  187. return src_block_id
  188. self._free_block_id(block)
  189. trg_block_id = self._allocate_block_id()
  190. self._cow_tracker.record_cow(src_block_id, trg_block_id)
  191. return trg_block_id
  192. def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
  193. """Returns the copy-on-write source->destination mapping and clears it.
  194. Returns:
  195. List[Tuple[BlockId, BlockId]]: A list mapping source
  196. block indices to destination block indices.
  197. """
  198. return self._cow_tracker.clear_cows()
  199. def mark_blocks_as_accessed(self, block_ids: List[int],
  200. now: float) -> None:
  201. """Mark blocks as accessed, used in prefix caching.
  202. Since the naive allocator does not implement prefix caching, we do
  203. nothing.
  204. """
  205. pass
  206. def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
  207. """Mark blocks as computed, used in prefix caching.
  208. Since the naive allocator does not implement prefix caching, we do
  209. nothing.
  210. """
  211. pass
  212. def get_computed_block_ids(self, prev_computed_block_ids: List[int],
  213. block_ids: List[int],
  214. skip_last_block_id: bool) -> List[int]:
  215. """No prefix caching here => return empty list
  216. """
  217. return []
  218. def get_common_computed_block_ids(
  219. self, computed_seq_block_ids: List[List[int]]) -> List[int]:
  220. """Determine blocks that can be skipped in prefill.
  221. Since the naive allocator does not support prefix caching, always return
  222. an empty list.
  223. """
  224. return []
  225. def promote_to_immutable_block(self, block: Block) -> BlockId:
  226. raise NotImplementedError("There is no promotion for naive blocks")
  227. def get_num_blocks_touched(self,
  228. blocks: List[Block],
  229. num_lookahead_slots: int = 0) -> int:
  230. """Determine the number of blocks that will be touched by
  231. swapping in/out the given blocks from certain sequence
  232. group with the provided num_lookahead_slots.
  233. Args:
  234. blocks (List[Block]): The potential blocks to swap.
  235. num_lookahead_slots (int): number of lookahead slots (0 for swap
  236. out).
  237. Returns:
  238. int: the number of blocks that will be touched by
  239. swapping in/out the given blocks and num_lookahead_slots.
  240. """
  241. # NOTE: for naive block, we use set to eliminate common blocks among
  242. # seqs, also we compare the empty slots in the mutable blocks with
  243. # lookahead slots to get the number of unique new block that are
  244. # needed.
  245. old_block_set = set()
  246. new_block_count = 0
  247. # TODO: make sure the logic is correct and clean it up.
  248. for block in blocks:
  249. if not block.is_full and num_lookahead_slots != 0:
  250. new_block_count += 1
  251. if num_lookahead_slots > block.num_empty_slots:
  252. new_block_count += cdiv(
  253. num_lookahead_slots - block.num_empty_slots,
  254. self._block_size)
  255. else:
  256. old_block_set.add(block.block_id)
  257. num_touched_blocks = new_block_count + len(old_block_set)
  258. return num_touched_blocks
  259. def swap_out(self, blocks: List[Block]) -> None:
  260. for block in blocks:
  261. self._free_block_id(block)
  262. def swap_in(self, blocks: List[Block]) -> None:
  263. for block in blocks:
  264. # Here we allocate either immutable or mutable block and then
  265. # extract its block_id. Note that the block object is released
  266. # and the block_id is assigned to "block" to allow reusing the
  267. # existing "block" object
  268. if block.is_full:
  269. tmp_block = self.allocate_immutable_block(
  270. prev_block=block.prev_block, token_ids=block.token_ids)
  271. else:
  272. tmp_block = self.allocate_mutable_block(
  273. prev_block=block.prev_block)
  274. tmp_block.append_token_ids(block.token_ids)
  275. block_id = tmp_block.block_id
  276. tmp_block.block_id = None
  277. self._block_pool.free_block(tmp_block)
  278. block.block_id = block_id # Assign block_id
  279. class NaiveBlock(Block):
  280. """An implementation of the Block class that does not support prefix
  281. caching.
  282. The NaiveBlock class represents a block of token IDs with a fixed size. It
  283. provides methods for appending token IDs to the block and manages copy-on
  284. -write operations when necessary.
  285. Args:
  286. prev_block (Block): The previous block in the sequence.
  287. token_ids (List[int]): The initial token IDs to be stored in the block.
  288. block_size (int): The maximum number of token IDs that can be stored in
  289. the block.
  290. allocator (BlockAllocator): The block allocator associated with this
  291. block.
  292. block_id (Optional[int], optional): The physical block index
  293. of this block. Defaults to None, which means no allocation has been
  294. made.
  295. _cow_target (Optional[Block], optional): The copy-on-write target block.
  296. If not provided, it defaults to self.
  297. """
  298. def __init__(self,
  299. prev_block: Optional[Block],
  300. token_ids: List[int],
  301. block_size: int,
  302. allocator: BlockAllocator,
  303. block_id: Optional[int] = None,
  304. _cow_target: Optional[Block] = None):
  305. self._token_ids: List[int] = []
  306. self._block_size = block_size
  307. self._prev_block = prev_block
  308. self._block_id = block_id
  309. self._allocator = allocator
  310. self._cow_target = _cow_target if _cow_target is not None else self
  311. self._append_token_ids_no_cow(token_ids)
  312. def append_token_ids(self, token_ids: List[int]) -> None:
  313. """Appends the given token IDs to the block and performs a
  314. copy-on-write if necessary.
  315. Args:
  316. token_ids (Optional[List[int]]): The token IDs to be appended
  317. to the block.
  318. """
  319. self._append_token_ids_no_cow(token_ids)
  320. if self._block_id is not None:
  321. self._block_id = (self._allocator.cow_block_if_not_appendable(
  322. self._cow_target))
  323. def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
  324. """Appends the given token IDs to the block
  325. Args:
  326. token_ids (List[int]): The token IDs to be appended to the block.
  327. """
  328. if len(token_ids) == 0:
  329. return
  330. assert len(token_ids) <= self.num_empty_slots
  331. self._token_ids.extend(token_ids)
  332. @property
  333. def computed(self) -> bool:
  334. raise NotImplementedError
  335. @computed.setter
  336. def computed(self, value) -> None:
  337. raise NotImplementedError
  338. @property
  339. def last_accessed(self) -> float:
  340. raise NotImplementedError
  341. @last_accessed.setter
  342. def last_accessed(self, last_accessed_ts: float):
  343. raise NotImplementedError
  344. @property
  345. def block_id(self) -> Optional[int]:
  346. return self._block_id
  347. @block_id.setter
  348. def block_id(self, value: Optional[int]) -> None:
  349. self._block_id = value
  350. @property
  351. def is_full(self) -> bool:
  352. return self.num_empty_slots == 0
  353. @property
  354. def num_empty_slots(self) -> int:
  355. return self._block_size - len(self.token_ids)
  356. @property
  357. def token_ids(self) -> List[int]:
  358. return self._token_ids
  359. @property
  360. def num_tokens_total(self) -> int:
  361. raise NotImplementedError(
  362. "num_tokens_total is not used for naive block")
  363. @property
  364. def block_size(self) -> int:
  365. return self._block_size
  366. @property
  367. def prev_block(self) -> Optional["Block"]:
  368. return self._prev_block
  369. @property
  370. def content_hash(self) -> Optional[int]:
  371. return None