  1. """A block manager that manages token blocks."""
  2. import enum
  3. from itertools import count, takewhile
  4. from os.path import commonprefix
  5. from typing import Dict, List, Optional, Set, Tuple
  6. from aphrodite.common.block import BlockTable, PhysicalTokenBlock
  7. from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
  8. from aphrodite.common.utils import Device
  9. from aphrodite.processing.evictor import Evictor, EvictionPolicy, make_evictor


class BlockAllocator:
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free
    list.
    """

    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
                 enable_caching: bool = False) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.enable_caching = enable_caching

        self.current_num_blocks = 0
        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

        # Switch over to FIFO eviction when caching is disabled.
        if not self.enable_caching:
            eviction_policy = EvictionPolicy.FIFO
        self.evictor: Evictor = make_evictor(eviction_policy)

        self.default_hash_ctr = count()

    def allocate_block(self, block_hash: int,
                       num_hashed_tokens: int) -> PhysicalTokenBlock:
        if self.current_num_blocks == self.num_blocks:
            block = self.evictor.evict()
            block.block_hash = block_hash
            block.num_hashed_tokens = num_hashed_tokens
            return block
        block = PhysicalTokenBlock(device=self.device,
                                   block_number=self.current_num_blocks,
                                   block_size=self.block_size,
                                   block_hash=block_hash,
                                   num_hashed_tokens=num_hashed_tokens)
        self.current_num_blocks += 1
        return block

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        # If caching is disabled, just allocate a new block and return it.
        if not self.enable_caching:
            block = self.allocate_block(next(self.default_hash_ctr),
                                        num_hashed_tokens)
            block.ref_count += 1
            return block

        if block_hash is None:
            block_hash = next(self.default_hash_ctr)
        if block_hash in self.evictor:
            assert block_hash not in self.cached_blocks
            block = self.evictor.remove(block_hash)
            assert block.ref_count == 0
            self.cached_blocks[block_hash] = block
            block.ref_count += 1
            assert block.block_hash == block_hash
            return block
        if block_hash not in self.cached_blocks:
            self.cached_blocks[block_hash] = self.allocate_block(
                block_hash, num_hashed_tokens)
        block = self.cached_blocks[block_hash]
        assert block.block_hash == block_hash
        block.ref_count += 1
        return block
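
    # Illustrative sketch (not part of the original module): with caching
    # enabled, allocating the same block_hash twice returns the same physical
    # block and bumps its reference count instead of consuming a new block.
    #
    #   allocator = BlockAllocator(Device.GPU, block_size=16, num_blocks=8,
    #                              enable_caching=True)
    #   a = allocator.allocate(block_hash=1234, num_hashed_tokens=16)
    #   b = allocator.allocate(block_hash=1234, num_hashed_tokens=16)
    #   assert a is b and a.ref_count == 2
    #   assert allocator.get_num_free_blocks() == 7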

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            assert block.block_hash not in self.evictor
            self.evictor.add(block)

            # If caching is enabled, remove the block from the cached_blocks.
            if self.enable_caching:
                del self.cached_blocks[block.block_hash]

    def get_num_free_blocks(self) -> int:
        return (self.num_blocks - self.current_num_blocks +
                self.evictor.num_blocks)
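
    # Accounting note (illustrative): blocks held by the evictor still count
    # as free, since they carry no live references and can be reclaimed.
    # For example, with num_blocks=8, current_num_blocks=5 and two evictable
    # blocks, the allocator reports 8 - 5 + 2 = 5 free blocks.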

    def contains_block(self, block_hash: int) -> bool:
        return block_hash in self.cached_blocks or block_hash in self.evictor

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        # If caching is enabled, update the hash of the block and the
        # cached_blocks dictionary.
        if self.enable_caching:
            assert not self.contains_block(block_hash)
            old_hash = block.block_hash
            block.block_hash = block_hash
            del self.cached_blocks[old_hash]
            self.cached_blocks[block_hash] = block


class AllocStatus(enum.Enum):
    """Result for BlockSpaceManager.can_allocate.

    1. OK: seq_group can be allocated now.
    2. LATER: seq_group cannot be allocated now, but the allocator's total
       capacity is larger than what seq_group requires, so it may be
       allocated later.
    3. NEVER: seq_group can never be allocated because it is too large to
       fit in the GPU.
    """
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()


class BlockSpaceManager:
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        self.block_sliding_window = None
        if sliding_window is not None:
            assert sliding_window % block_size == 0, (sliding_window,
                                                      block_size)
            self.block_sliding_window = sliding_window // block_size

        self.watermark = watermark
        assert watermark >= 0.0

        self.enable_caching = enable_caching

        self.watermark_blocks = int(watermark * num_gpu_blocks)
        self.gpu_allocator = BlockAllocator(Device.GPU,
                                            block_size,
                                            num_gpu_blocks,
                                            enable_caching=enable_caching)
        self.cpu_allocator = BlockAllocator(Device.CPU,
                                            block_size,
                                            num_cpu_blocks,
                                            enable_caching=enable_caching)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        # FIXME: Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = len(seq.logical_token_blocks)
        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER
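
    # Worked example (illustrative): with num_gpu_blocks=1000 and
    # watermark=0.01, watermark_blocks == 10. A prompt needing more than 990
    # blocks can never satisfy 1000 - required >= 10 and is rejected with
    # NEVER; a smaller prompt is admitted (OK) once at least required + 10
    # GPU blocks are free, and deferred (LATER) otherwise.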

    def allocate(self, seq_group: SequenceGroup) -> None:
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

        # Allocate new physical token blocks that will store the prompt
        # tokens.
        num_prompt_blocks = len(seq.logical_token_blocks)
        block_table: BlockTable = []
        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
            else:
                block = self.gpu_allocator.allocate(
                    seq.hash_of_block(logical_idx),
                    seq.num_hashed_tokens_of_block(logical_idx))
            block_table.append(block)

        # Assign the block table for each sequence.
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            self.block_tables[seq.seq_id] = block_table.copy()
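
    # Illustrative sketch: with block_size=16 and sliding_window=64
    # (block_sliding_window == 4), a 6-block prompt maps logical blocks 0..5
    # onto physical slots [b0, b1, b2, b3, b0, b1]: logical block 4 reuses
    # block_table[4 % 4] == b0 and logical block 5 reuses b1, so the window
    # never holds more than 4 physical blocks.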

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def _promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        # Compute a new hash for the block so that it can be shared by other
        # sequences.
        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

        # If new_hash is already in the cached table, then free last_block
        # and return the cached version.
        if self.gpu_allocator.contains_block(new_hash):
            self.gpu_allocator.free(last_block)
            return self.gpu_allocator.allocate(new_hash)
        else:
            self.gpu_allocator.update_hash(new_hash, last_block)
            return last_block

    def _is_last_block_full(
        self,
        seq: Sequence,
    ) -> bool:
        token_ids_len = len(seq.data.get_token_ids())
        return token_ids_len > 0 and token_ids_len % seq.block_size == 0
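
    # Example (illustrative): with block_size=16, a sequence holding 32
    # tokens has a full last block (32 % 16 == 0) and is eligible for
    # promotion, while 33 tokens leave the last block partially filled and
    # therefore not yet shareable.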

    def _maybe_promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        if self._is_last_block_full(seq):
            return self._promote_last_block(seq, last_block)
        else:
            return last_block

    def _allocate_last_physical_block(
        self,
        seq: Sequence,
    ) -> PhysicalTokenBlock:
        block_hash: Optional[int] = None
        if self._is_last_block_full(seq):
            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
        num_hashed_tokens = seq.num_hashed_tokens_of_block(
            len(seq.logical_token_blocks) - 1)
        new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
        if block_hash is None:
            assert new_block.ref_count == 1
        return new_block

    def append_slot(
        self,
        seq: Sequence,
    ) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for a new token."""
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]
        # If we need to allocate a new physical block
        if len(block_table) < len(logical_blocks):
            # Currently this code only supports adding one physical block
            assert len(block_table) == len(logical_blocks) - 1

            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # reuse a block
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                new_block = self._allocate_last_physical_block(seq)
                block_table.append(new_block)
                return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            # If the last block is now complete, promote it to a full block
            # so that it can be shared.
            new_block = self._maybe_promote_last_block(seq, last_block)
            block_table[-1] = new_block
            return None
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self._allocate_last_physical_block(seq)

            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return last_block.block_number, new_block.block_number
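
    # Copy-on-write sketch (illustrative, assuming block_manager is an
    # instance of this class): if two forked sequences still share the last
    # block (ref_count == 2) and one of them appends a token, that sequence
    # gets a fresh block and the caller receives the pair of block numbers
    # whose contents must be copied, e.g.:
    #
    #   result = block_manager.append_slot(seq)
    #   if result is not None:
    #       src, dst = result  # copy the contents of block src into dst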

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        for block in src_block_table:
            block.ref_count += 1
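
    # Illustrative sketch: forking copies only the block table, so parent
    # and child share every physical block until one of them writes.
    #
    #   block_manager.fork(parent_seq, child_seq)
    #   # every shared block now has ref_count >= 2; append_slot() performs
    #   # the copy-on-write once either sequence grows past the shared state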

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for cpu_block in block_table:
                if cpu_block in mapping:
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate(
                        cpu_block.block_hash, cpu_block.num_hashed_tokens)
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping
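
    # Illustrative sketch: swap_in returns a CPU -> GPU block-number mapping
    # that the caller uses to schedule the actual memory copies, e.g.
    #
    #   mapping = block_manager.swap_in(seq_group)
    #   # a mapping like {17: 3, 18: 4} means: copy CPU block 17 into GPU
    #   # block 3 and CPU block 18 into GPU block 4, then resume the group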

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if gpu_block in mapping:
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate(
                        gpu_block.block_hash, gpu_block.num_hashed_tokens)
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        for block in set(block_table):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            # Already freed or hasn't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()

    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        block_table = self.block_tables[seq.seq_id]
        for block in block_table:
            block.last_accessed = access_time

    def compute_full_blocks_in_seq(self, seq: Sequence):
        if seq.seq_id not in self.block_tables:
            return
        max_full_block = seq.get_len() // self.block_size - 1
        block_table = self.block_tables[seq.seq_id]
        if max_full_block == -1:
            return
        for i in reversed(range(max_full_block)):
            if block_table[i].computed:
                break
            block_table[i].computed = True
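
    # Example (illustrative): with block_size=16 and a 40-token sequence,
    # max_full_block == 40 // 16 - 1 == 1, so only block_table[0] is marked
    # computed; block 1 (the last full block) and block 2 (the partial block)
    # are left unmarked. The reverse scan stops early at the first block that
    # was already marked computed by a previous step.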

    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
        if seq.seq_id not in self.block_tables:
            return []
        block_table = self.block_tables[seq.seq_id]
        # NOTE: We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        return [
            b.block_number
            for b in takewhile(lambda b: b.computed, block_table[:-1])
        ]
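
    # Example (illustrative): for computed flags of [True, True, False, True]
    # across the block table, takewhile() stops at the first False, so only
    # the first two block numbers are returned; block_table[:-1] additionally
    # guarantees the final block is never reported, even if it is computed.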

    def get_common_computed_block_ids(self,
                                      seq_group: SequenceGroup) -> List[int]:
        # Can return non-empty result only with prefix caching enabled.
        if not self.enable_caching:
            return []

        ids_list = [
            self.get_all_computed_blocks(seq)
            for seq in iter(seq_group.seqs_dict.values())
        ]
        return commonprefix([ids for ids in ids_list if ids != []])
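
    # Example (illustrative): commonprefix() here compares lists of block
    # numbers rather than strings. For per-sequence computed-block lists
    # [[0, 1, 2], [0, 1, 5]] it returns [0, 1], i.e. only the prefix of
    # blocks that every sequence with a non-empty list has already computed.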

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        if self.enable_caching:
            for seq in seq_group.seqs_dict.values():
                self.compute_full_blocks_in_seq(seq)