1
0

naive_block.py 16 KB


  1. from collections import deque
  2. from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
  3. from aphrodite.common.utils import cdiv
  4. from aphrodite.processing.block.common import (BlockPool, CopyOnWriteTracker,
  5. RefCounter,
  6. get_all_blocks_recursively)
  7. from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
  8. BlockId, Device)
# Type alias: a reference count is represented as a plain ``int``.
Refcount = int
  10. class NaiveBlockAllocator(BlockAllocator):
  11. """A simple block allocator that manages blocks of memory without prefix
  12. caching.
  13. Args:
  14. create_block (Block.Factory): A factory function for creating new
  15. blocks. This is used when a NaiveBlockAllocator is composed within
  16. a prefix caching allocator -- the naive block allocator must
  17. construct prefix caching blocks (but shouldn't know anything else
  18. about them).
  19. num_blocks (int): The total number of blocks to manage.
  20. block_size (int): The size of each block in tokens.
  21. block_ids (Optional[Iterable[int]], optional): An optional iterable of
  22. block IDs. If not provided, block IDs will be assigned sequentially
  23. from 0 to num_blocks - 1.
  24. """
  25. def __init__(
  26. self,
  27. create_block: Block.Factory,
  28. num_blocks: int,
  29. block_size: int,
  30. block_ids: Optional[Iterable[int]] = None,
  31. block_pool: Optional[BlockPool] = None,
  32. ):
  33. if block_ids is None:
  34. block_ids = range(num_blocks)
  35. self._free_block_indices: Deque[BlockId] = deque(block_ids)
  36. self._all_block_indices = frozenset(block_ids)
  37. assert len(self._all_block_indices) == num_blocks
  38. self._refcounter = RefCounter(
  39. all_block_indices=self._free_block_indices)
  40. self._block_size = block_size
  41. self._cow_tracker = CopyOnWriteTracker(
  42. refcounter=self._refcounter.as_readonly())
  43. if block_pool is None:
  44. extra_factor = 4
  45. # Pre-allocate "num_blocks * extra_factor" block objects.
  46. # The "* extra_factor" is a buffer to allow more block objects
  47. # than physical blocks
  48. self._block_pool = BlockPool(self._block_size, create_block, self,
  49. num_blocks * extra_factor)
  50. else:
  51. # In this case, the block pool is provided by the caller,
  52. # which means that there is most likely a need to share
  53. # a block pool between allocators
  54. self._block_pool = block_pool
  55. def allocate_immutable_block(self,
  56. prev_block: Optional[Block],
  57. token_ids: List[int],
  58. device: Optional[Device] = None) -> Block:
  59. """Allocates a new immutable block with the given token IDs, linked to
  60. the previous block.
  61. Args:
  62. prev_block (Optional[Block]): The previous block in the sequence. If
  63. None, then the block to be allocated is the first block in the
  64. sequence.
  65. token_ids (List[int]): The token IDs to be stored in the new block.
  66. Returns:
  67. Block: The newly allocated immutable block.
  68. """
  69. assert device is None
  70. block = self.allocate_mutable_block(prev_block=prev_block)
  71. block.append_token_ids(token_ids)
  72. return block
  73. def allocate_immutable_blocks(
  74. self,
  75. prev_block: Optional[Block],
  76. block_token_ids: List[List[int]],
  77. device: Optional[Device] = None) -> List[Block]:
  78. assert device is None
  79. num_blocks = len(block_token_ids)
  80. block_ids = []
  81. for i in range(num_blocks):
  82. block_ids.append(self._allocate_block_id())
  83. blocks = []
  84. for i in range(num_blocks):
  85. prev_block = self._block_pool.init_block(
  86. prev_block=prev_block,
  87. token_ids=block_token_ids[i],
  88. block_size=self._block_size,
  89. physical_block_id=block_ids[i])
  90. blocks.append(prev_block)
  91. return blocks
  92. def allocate_mutable_block(self,
  93. prev_block: Optional[Block],
  94. device: Optional[Device] = None) -> Block:
  95. """Allocates a new mutable block, linked to the previous block.
  96. Args:
  97. prev_block (Optional[Block]): The previous block in the sequence. If
  98. None, then the block to be allocated is the first block in the
  99. sequence.
  100. Returns:
  101. Block: The newly allocated mutable block.
  102. """
  103. assert device is None
  104. block_id = self._allocate_block_id()
  105. block = self._block_pool.init_block(prev_block=prev_block,
  106. token_ids=[],
  107. block_size=self._block_size,
  108. physical_block_id=block_id)
  109. return block
  110. def _allocate_block_id(self) -> BlockId:
  111. if not self._free_block_indices:
  112. raise BlockAllocator.NoFreeBlocksError()
  113. block_id = self._free_block_indices.popleft()
  114. self._refcounter.incr(block_id)
  115. return block_id
  116. def _free_block_id(self, block: Block) -> None:
  117. block_id = block.block_id
  118. assert block_id is not None
  119. refcount = self._refcounter.decr(block_id)
  120. if refcount == 0:
  121. self._free_block_indices.appendleft(block_id)
  122. block.block_id = None
  123. def free(self, block: Block, keep_block_object: bool = False) -> None:
  124. # Release the physical block id
  125. self._free_block_id(block)
  126. # Release the block object
  127. if not keep_block_object:
  128. self._block_pool.free_block(block)
  129. def fork(self, last_block: Block) -> List[Block]:
  130. """Creates a new sequence of blocks that shares the same underlying
  131. memory as the original sequence.
  132. Args:
  133. last_block (Block): The last block in the original sequence.
  134. Returns:
  135. List[Block]: The new sequence of blocks that shares the same memory
  136. as the original sequence.
  137. """
  138. source_blocks = get_all_blocks_recursively(last_block)
  139. forked_blocks: List[Block] = []
  140. prev_block = None
  141. for block in source_blocks:
  142. # Increment refcount for each block.
  143. assert block.block_id is not None
  144. refcount = self._refcounter.incr(block.block_id)
  145. assert refcount != 1, "can't fork free'd block"
  146. forked_block = self._block_pool.init_block(
  147. prev_block=prev_block,
  148. token_ids=block.token_ids,
  149. block_size=self._block_size,
  150. physical_block_id=block.block_id)
  151. forked_blocks.append(forked_block)
  152. prev_block = forked_blocks[-1]
  153. return forked_blocks
  154. def get_num_free_blocks(self) -> int:
  155. return len(self._free_block_indices)
  156. def get_num_total_blocks(self) -> int:
  157. return len(self._all_block_indices)
  158. def get_physical_block_id(self, absolute_id: int) -> int:
  159. """Returns the zero-offset block id on certain block allocator
  160. given the absolute block id.
  161. Args:
  162. absolute_id (int): The absolute block id for the block
  163. in whole allocator.
  164. Returns:
  165. int: The zero-offset block id on certain device.
  166. """
  167. return sorted(self._all_block_indices).index(absolute_id)
  168. @property
  169. def refcounter(self):
  170. return self._refcounter
  171. @property
  172. def all_block_ids(self) -> FrozenSet[int]:
  173. return self._all_block_indices
  174. def cow_block_if_not_appendable(self, block: Block) -> BlockId:
  175. """Performs a copy-on-write operation on the given block if it is not
  176. appendable.
  177. Args:
  178. block (Block): The block to check for copy-on-write.
  179. Returns:
  180. BlockId: The block index of the new block if a copy-on-write
  181. operation was performed, or the original block index if
  182. no copy-on-write was necessary.
  183. """
  184. src_block_id = block.block_id
  185. assert src_block_id is not None
  186. if self._cow_tracker.is_appendable(block):
  187. return src_block_id
  188. self._free_block_id(block)
  189. trg_block_id = self._allocate_block_id()
  190. self._cow_tracker.record_cow(src_block_id, trg_block_id)
  191. return trg_block_id
  192. def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
  193. """Returns the copy-on-write source->destination mapping and clears it.
  194. Returns:
  195. List[Tuple[BlockId, BlockId]]: A list mapping source
  196. block indices to destination block indices.
  197. """
  198. return self._cow_tracker.clear_cows()
  199. def mark_blocks_as_accessed(self, block_ids: List[int],
  200. now: float) -> None:
  201. """Mark blocks as accessed, used in prefix caching.
  202. Since the naive allocator does not implement prefix caching, we do
  203. nothing.
  204. """
  205. pass
  206. def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
  207. """Mark blocks as computed, used in prefix caching.
  208. Since the naive allocator does not implement prefix caching, we do
  209. nothing.
  210. """
  211. pass
  212. def get_computed_block_ids(self, prev_computed_block_ids: List[int],
  213. block_ids: List[int],
  214. skip_last_block_id: bool) -> List[int]:
  215. """No prefix caching here => return empty list
  216. """
  217. return []
  218. def get_common_computed_block_ids(
  219. self, computed_seq_block_ids: List[List[int]]) -> List[int]:
  220. """Determine blocks that can be skipped in prefill.
  221. Since the naive allocator does not support prefix caching, always return
  222. an empty list.
  223. """
  224. return []
  225. def promote_to_immutable_block(self, block: Block) -> BlockId:
  226. raise NotImplementedError("There is no promotion for naive blocks")
  227. def get_num_blocks_touched(self,
  228. blocks: List[Block],
  229. num_lookahead_slots: int = 0) -> int:
  230. """Determine the number of blocks that will be touched by
  231. swapping in/out the given blocks from certain sequence
  232. group with the provided num_lookahead_slots.
  233. Args:
  234. blocks (List[Block]): The potential blocks to swap.
  235. num_lookahead_slots (int): number of lookahead slots (0 for swap
  236. out).
  237. Returns:
  238. int: the number of blocks that will be touched by
  239. swapping in/out the given blocks and num_lookahead_slots.
  240. """
  241. # NOTE: for naive block, we use set to eliminate common blocks among
  242. # seqs, also we compare the empty slots in the mutable blocks with
  243. # lookahead slots to get the number of unique new block that are
  244. # needed.
  245. old_block_set = set()
  246. new_block_count = 0
  247. # TODO: make sure the logic is correct and clean it up.
  248. for block in blocks:
  249. if not block.is_full and num_lookahead_slots != 0:
  250. new_block_count += 1
  251. if num_lookahead_slots > block.num_empty_slots:
  252. new_block_count += cdiv(
  253. num_lookahead_slots - block.num_empty_slots,
  254. self._block_size)
  255. else:
  256. old_block_set.add(block.block_id)
  257. num_touched_blocks = new_block_count + len(old_block_set)
  258. return num_touched_blocks
  259. def swap_out(self, blocks: List[Block]) -> None:
  260. for block in blocks:
  261. self._free_block_id(block)
  262. def swap_in(self, blocks: List[Block]) -> None:
  263. for block in blocks:
  264. # Here we allocate either immutable or mutable block and then
  265. # extract its block_id. Note that the block object is released
  266. # and the block_id is assigned to "block" to allow reusing the
  267. # existing "block" object
  268. if block.is_full:
  269. tmp_block = self.allocate_immutable_block(
  270. prev_block=block.prev_block, token_ids=block.token_ids)
  271. else:
  272. tmp_block = self.allocate_mutable_block(
  273. prev_block=block.prev_block)
  274. tmp_block.append_token_ids(block.token_ids)
  275. block_id = tmp_block.block_id
  276. tmp_block.block_id = None
  277. self._block_pool.free_block(tmp_block)
  278. block.block_id = block_id # Assign block_id
  279. class NaiveBlock(Block):
  280. """An implementation of the Block class that does not support prefix
  281. caching.
  282. The NaiveBlock class represents a block of token IDs with a fixed size. It
  283. provides methods for appending token IDs to the block and manages copy-on
  284. -write operations when necessary.
  285. Args:
  286. prev_block (Block): The previous block in the sequence.
  287. token_ids (List[int]): The initial token IDs to be stored in the block.
  288. block_size (int): The maximum number of token IDs that can be stored in
  289. the block.
  290. allocator (BlockAllocator): The block allocator associated with this
  291. block.
  292. block_id (Optional[int], optional): The physical block index
  293. of this block. Defaults to None, which means no allocation has been
  294. made.
  295. _cow_target (Optional[Block], optional): The copy-on-write target block.
  296. If not provided, it defaults to self.
  297. """
  298. def __init__(self,
  299. prev_block: Optional[Block],
  300. token_ids: List[int],
  301. block_size: int,
  302. allocator: BlockAllocator,
  303. block_id: Optional[int] = None,
  304. _cow_target: Optional[Block] = None):
  305. self._token_ids: List[int] = []
  306. self._block_size = block_size
  307. self._prev_block = prev_block
  308. self._block_id = block_id
  309. self._allocator = allocator
  310. self._cow_target = _cow_target if _cow_target is not None else self
  311. self._append_token_ids_no_cow(token_ids)
  312. def append_token_ids(self, token_ids: List[int]) -> None:
  313. """Appends the given token IDs to the block and performs a
  314. copy-on-write if necessary.
  315. Args:
  316. token_ids (Optional[List[int]]): The token IDs to be appended
  317. to the block.
  318. """
  319. self._append_token_ids_no_cow(token_ids)
  320. if self._block_id is not None:
  321. self._block_id = (self._allocator.cow_block_if_not_appendable(
  322. self._cow_target))
  323. def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
  324. """Appends the given token IDs to the block
  325. Args:
  326. token_ids (List[int]): The token IDs to be appended to the block.
  327. """
  328. if len(token_ids) == 0:
  329. return
  330. assert len(token_ids) <= self.num_empty_slots
  331. self._token_ids.extend(token_ids)
  332. @property
  333. def computed(self) -> bool:
  334. raise NotImplementedError
  335. @computed.setter
  336. def computed(self, value) -> None:
  337. raise NotImplementedError
  338. @property
  339. def last_accessed(self) -> float:
  340. raise NotImplementedError
  341. @last_accessed.setter
  342. def last_accessed(self, last_accessed_ts: float):
  343. raise NotImplementedError
  344. @property
  345. def block_id(self) -> Optional[int]:
  346. return self._block_id
  347. @block_id.setter
  348. def block_id(self, value: Optional[int]) -> None:
  349. self._block_id = value
  350. @property
  351. def is_full(self) -> bool:
  352. return self.num_empty_slots == 0
  353. @property
  354. def num_empty_slots(self) -> int:
  355. return self._block_size - len(self.token_ids)
  356. @property
  357. def token_ids(self) -> List[int]:
  358. return self._token_ids
  359. @property
  360. def num_tokens_total(self) -> int:
  361. raise NotImplementedError(
  362. "num_tokens_total is not used for naive block")
  363. @property
  364. def block_size(self) -> int:
  365. return self._block_size
  366. @property
  367. def prev_block(self) -> Optional["Block"]:
  368. return self._prev_block
  369. @property
  370. def content_hash(self) -> Optional[int]:
  371. return None