# cpu_gpu_block_allocator.py (15 KB)


  1. from typing import Dict, FrozenSet, List, Optional, Tuple
  2. from aphrodite.common.utils import Device
  3. from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
  4. BlockId,
  5. DeviceAwareBlockAllocator)
  6. from aphrodite.processing.block.naive_block import (NaiveBlock,
  7. NaiveBlockAllocator)
  8. from aphrodite.processing.block.prefix_caching_block import (
  9. PrefixCachingBlockAllocator)
  10. class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
  11. """A block allocator that can allocate blocks on both CPU and GPU memory.
  12. This class implements the `DeviceAwareBlockAllocator` interface and provides
  13. functionality for allocating and managing blocks of memory on both CPU and
  14. GPU devices.
  15. The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
  16. blocks, and allows for allocation, deallocation, forking, and swapping of
  17. blocks across these memory pools.
  18. """
  19. @staticmethod
  20. def create(
  21. allocator_type: str,
  22. num_gpu_blocks: int,
  23. num_cpu_blocks: int,
  24. block_size: int,
  25. ) -> DeviceAwareBlockAllocator:
  26. """Creates a CpuGpuBlockAllocator instance with the specified
  27. configuration.
  28. This static method creates and returns a CpuGpuBlockAllocator instance
  29. based on the provided parameters. It initializes the CPU and GPU block
  30. allocators with the specified number of blocks, block size, and
  31. allocator type.
  32. Args:
  33. allocator_type (str): The type of block allocator to use for CPU
  34. and GPU blocks. Currently supported values are "naive" and
  35. "prefix_caching".
  36. num_gpu_blocks (int): The number of blocks to allocate for GPU
  37. memory.
  38. num_cpu_blocks (int): The number of blocks to allocate for CPU
  39. memory.
  40. block_size (int): The size of each block in number of tokens.
  41. Returns:
  42. DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
  43. specified configuration.
  44. Notes:
  45. - The block IDs are assigned contiguously, with GPU block IDs coming
  46. before CPU block IDs.
  47. """
  48. block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
  49. gpu_block_ids = block_ids[:num_gpu_blocks]
  50. cpu_block_ids = block_ids[num_gpu_blocks:]
  51. if allocator_type == "naive":
  52. gpu_allocator: BlockAllocator = NaiveBlockAllocator(
  53. create_block=NaiveBlock, # type: ignore
  54. num_blocks=num_gpu_blocks,
  55. block_size=block_size,
  56. block_ids=gpu_block_ids,
  57. )
  58. cpu_allocator: BlockAllocator = NaiveBlockAllocator(
  59. create_block=NaiveBlock, # type: ignore
  60. num_blocks=num_cpu_blocks,
  61. block_size=block_size,
  62. block_ids=cpu_block_ids,
  63. )
  64. elif allocator_type == "prefix_caching":
  65. gpu_allocator = PrefixCachingBlockAllocator(
  66. num_blocks=num_gpu_blocks,
  67. block_size=block_size,
  68. block_ids=gpu_block_ids,
  69. )
  70. cpu_allocator = PrefixCachingBlockAllocator(
  71. num_blocks=num_cpu_blocks,
  72. block_size=block_size,
  73. block_ids=cpu_block_ids,
  74. )
  75. else:
  76. raise ValueError(f"Unknown allocator type {allocator_type=}")
  77. return CpuGpuBlockAllocator(
  78. cpu_block_allocator=cpu_allocator,
  79. gpu_block_allocator=gpu_allocator,
  80. )
  81. def __init__(self, cpu_block_allocator: BlockAllocator,
  82. gpu_block_allocator: BlockAllocator):
  83. assert not (
  84. cpu_block_allocator.all_block_ids
  85. & gpu_block_allocator.all_block_ids
  86. ), "cpu and gpu block allocators can't have intersection of block ids"
  87. self._allocators = {
  88. Device.CPU: cpu_block_allocator,
  89. Device.GPU: gpu_block_allocator,
  90. }
  91. self._swap_mapping: Dict[int, int] = {}
  92. self._null_block: Optional[Block] = None
  93. self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
  94. for _, allocator in self._allocators.items():
  95. for block_id in allocator.all_block_ids:
  96. self._block_ids_to_allocator[block_id] = allocator
  97. def allocate_or_get_null_block(self) -> Block:
  98. if self._null_block is None:
  99. self._null_block = NullBlock(
  100. self.allocate_mutable_block(None, Device.GPU))
  101. return self._null_block
  102. def allocate_mutable_block(self, prev_block: Optional[Block],
  103. device: Device) -> Block:
  104. """Allocates a new mutable block on the specified device.
  105. Args:
  106. prev_block (Optional[Block]): The previous block to in the sequence.
  107. Used for prefix hashing.
  108. device (Device): The device on which to allocate the new block.
  109. Returns:
  110. Block: The newly allocated mutable block.
  111. """
  112. return self._allocators[device].allocate_mutable_block(prev_block)
  113. def allocate_immutable_blocks(self, prev_block: Optional[Block],
  114. block_token_ids: List[List[int]],
  115. device: Optional[Device]) -> List[Block]:
  116. """Allocates a new group of immutable blocks with the provided block
  117. token IDs on the specified device.
  118. Args:
  119. prev_block (Optional[Block]): The previous block in the sequence.
  120. Used for prefix hashing.
  121. block_token_ids (List[int]): The list of block token IDs to be
  122. stored in the new blocks.
  123. device (Device): The device on which to allocate the new block.
  124. Returns:
  125. List[Block]: The newly allocated list of immutable blocks
  126. containing the provided block token IDs.
  127. """
  128. return self._allocators[device].allocate_immutable_blocks(
  129. prev_block, block_token_ids)
  130. def allocate_immutable_block(self, prev_block: Optional[Block],
  131. token_ids: List[int],
  132. device: Device) -> Block:
  133. """Allocates a new immutable block with the provided token IDs on the
  134. specified device.
  135. Args:
  136. prev_block (Optional[Block]): The previous block in the sequence.
  137. Used for prefix hashing.
  138. token_ids (List[int]): The list of token IDs to be stored in the new
  139. block.
  140. device (Device): The device on which to allocate the new block.
  141. Returns:
  142. Block: The newly allocated immutable block containing the provided
  143. token IDs.
  144. """
  145. return self._allocators[device].allocate_immutable_block(
  146. prev_block, token_ids)
  147. def free(self, block: Block) -> None:
  148. """Frees the memory occupied by the given block.
  149. Args:
  150. block (Block): The block to be freed.
  151. """
  152. # Null block should never be freed
  153. if isinstance(block, NullBlock):
  154. return
  155. block_id = block.block_id
  156. assert block_id is not None
  157. allocator = self._block_ids_to_allocator[block_id]
  158. allocator.free(block)
  159. def fork(self, last_block: Block) -> List[Block]:
  160. """Creates a new sequence of blocks that shares the same underlying
  161. memory as the original sequence.
  162. Args:
  163. last_block (Block): The last block in the original sequence.
  164. Returns:
  165. List[Block]: A new list of blocks that shares the same memory as the
  166. original sequence.
  167. """
  168. # do not attempt to fork the null block
  169. assert not isinstance(last_block, NullBlock)
  170. block_id = last_block.block_id
  171. assert block_id is not None
  172. allocator = self._block_ids_to_allocator[block_id]
  173. return allocator.fork(last_block)
  174. def get_num_free_blocks(self, device: Device) -> int:
  175. """Returns the number of free blocks available on the specified device.
  176. Args:
  177. device (Device): The device for which to query the number of free
  178. blocks. AssertionError is raised if None is passed.
  179. Returns:
  180. int: The number of free blocks available on the specified device.
  181. """
  182. return self._allocators[device].get_num_free_blocks()
  183. def get_num_total_blocks(self, device: Device) -> int:
  184. return self._allocators[device].get_num_total_blocks()
  185. def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
  186. """Returns the zero-offset block id on certain device given the
  187. absolute block id.
  188. Args:
  189. device (Device): The device for which to query relative block id.
  190. absolute_id (int): The absolute block id for the block in
  191. whole allocator.
  192. Returns:
  193. int: The zero-offset block id on certain device.
  194. """
  195. return self._allocators[device].get_physical_block_id(absolute_id)
  196. def swap(self, blocks: List[Block], src_device: Device,
  197. dst_device: Device) -> Dict[int, int]:
  198. """Execute the swap for the given blocks from source_device
  199. on to dest_device, save the current swap mapping and append
  200. them to the accumulated `self._swap_mapping` for each
  201. scheduling move.
  202. Args:
  203. blocks: List of blocks to be swapped.
  204. src_device (Device): Device to swap the 'blocks' from.
  205. dst_device (Device): Device to swap the 'blocks' to.
  206. Returns:
  207. Dict[int, int]: Swap mapping from source_device
  208. on to dest_device.
  209. """
  210. src_block_ids = [block.block_id for block in blocks]
  211. self._allocators[src_device].swap_out(blocks)
  212. self._allocators[dst_device].swap_in(blocks)
  213. dst_block_ids = [block.block_id for block in blocks]
  214. current_swap_mapping: Dict[int, int] = {}
  215. for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
  216. if src_block_id is not None and dst_block_id is not None:
  217. self._swap_mapping[src_block_id] = dst_block_id
  218. current_swap_mapping[src_block_id] = dst_block_id
  219. return current_swap_mapping
  220. def get_num_blocks_touched(self,
  221. blocks: List[Block],
  222. device: Device,
  223. num_lookahead_slots: int = 0) -> int:
  224. """Returns the number of blocks that will be touched by
  225. swapping in/out the given blocks on to the 'device'.
  226. Args:
  227. blocks: List of blocks to be swapped.
  228. device (Device): Device to swap the 'blocks' on.
  229. num_lookahead_slots (int): Number of lookahead slots used in
  230. speculative decoding, default to 0.
  231. Returns:
  232. int: the number of blocks that will be touched by
  233. swapping in/out the given blocks on to the 'device'.
  234. """
  235. return self._allocators[device].get_num_blocks_touched(
  236. blocks, num_lookahead_slots)
  237. def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
  238. """Clears the copy-on-write (CoW) state and returns the mapping of
  239. source to destination block IDs.
  240. Returns:
  241. List[Tuple[int, int]]: A list mapping source block IDs to
  242. destination block IDs.
  243. """
  244. # CoW only supported on GPU
  245. device = Device.GPU
  246. return self._allocators[device].clear_copy_on_writes()
  247. def mark_blocks_as_accessed(self, block_ids: List[int],
  248. now: float) -> None:
  249. """Mark blocks as accessed, only use for prefix caching."""
  250. # Prefix caching only supported on GPU.
  251. device = Device.GPU
  252. return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
  253. def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
  254. """Mark blocks as accessed, only use for prefix caching."""
  255. # Prefix caching only supported on GPU.
  256. device = Device.GPU
  257. return self._allocators[device].mark_blocks_as_computed(block_ids)
  258. def get_computed_block_ids(self, prev_computed_block_ids: List[int],
  259. block_ids: List[int],
  260. skip_last_block_id: bool) -> List[int]:
  261. # Prefix caching only supported on GPU.
  262. device = Device.GPU
  263. return self._allocators[device].get_computed_block_ids(
  264. prev_computed_block_ids, block_ids, skip_last_block_id)
  265. def get_common_computed_block_ids(
  266. self, computed_seq_block_ids: List[List[int]]) -> List[int]:
  267. # Prefix caching only supported on GPU.
  268. device = Device.GPU
  269. return self._allocators[device].get_common_computed_block_ids(
  270. computed_seq_block_ids)
  271. @property
  272. def all_block_ids(self) -> FrozenSet[int]:
  273. return frozenset(self._block_ids_to_allocator.keys())
  274. def get_prefix_cache_hit_rate(self, device: Device) -> float:
  275. """Prefix cache hit rate. -1 means not supported or disabled."""
  276. assert device in self._allocators
  277. return self._allocators[device].get_prefix_cache_hit_rate()
  278. def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
  279. """Returns and clears the mapping of source to destination block IDs.
  280. Will be called after every swapping operations for now, and after every
  281. schedule when BlockManagerV2 become default. Currently not useful.
  282. Returns:
  283. List[Tuple[int, int]]: A mapping of source to destination block IDs.
  284. """
  285. mapping = self._swap_mapping.copy()
  286. self._swap_mapping.clear()
  287. return list(mapping.items())
  288. class NullBlock(Block):
  289. """
  290. Null blocks are used as a placeholders for KV cache blocks that have
  291. been dropped due to sliding window.
  292. This implementation just wraps an ordinary block and prevents it from
  293. being modified. It also allows for testing if a block is NullBlock
  294. via isinstance().
  295. """
  296. def __init__(self, proxy: Block):
  297. super().__init__()
  298. self._proxy = proxy
  299. def append_token_ids(self, token_ids: List[BlockId]):
  300. raise ValueError("null block should not be modified")
  301. @property
  302. def block_id(self):
  303. return self._proxy.block_id
  304. @block_id.setter
  305. def block_id(self, value: Optional[BlockId]):
  306. raise ValueError("null block should not be modified")
  307. @property
  308. def token_ids(self) -> List[BlockId]:
  309. return self._proxy.token_ids
  310. @property
  311. def num_tokens_total(self) -> int:
  312. raise NotImplementedError(
  313. "num_tokens_total is not used for null block")
  314. @property
  315. def num_empty_slots(self) -> BlockId:
  316. return self._proxy.num_empty_slots
  317. @property
  318. def is_full(self):
  319. return self._proxy.is_full
  320. @property
  321. def prev_block(self):
  322. return self._proxy.prev_block
  323. @property
  324. def computed(self):
  325. return self._proxy.computed
  326. @computed.setter
  327. def computed(self, value):
  328. self._proxy.computed = value
  329. @property
  330. def last_accessed(self) -> float:
  331. return self._proxy.last_accessed
  332. @last_accessed.setter
  333. def last_accessed(self, last_accessed_ts: float):
  334. self._proxy.last_accessed = last_accessed_ts
  335. @property
  336. def content_hash(self):
  337. return self._proxy.content_hash