# scheduler.py
  1. import enum
  2. import time
  3. from typing import Dict, Iterable, List, Optional, Tuple, Union
  4. from aphrodite.common.config import CacheConfig, SchedulerConfig
  5. from aphrodite.processing.block_manager import BlockSpaceManager
  6. from aphrodite.processing.policy import PolicyFactory
  7. from aphrodite.common.logger import init_logger
  8. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  9. SequenceGroupMetadata, SequenceStatus)
  10. logger = init_logger(__name__)
  11. class PreemptionMode(enum.Enum):
  12. """Preemption modes.
  13. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  14. and swap them back in when the sequences are resumed.
  15. 2. Recomputation: Discard the blocks of the preempted sequences and
  16. recompute them when the sequences are resumed, treating the sequences as
  17. new prompts.
  18. """
  19. SWAP = enum.auto()
  20. RECOMPUTE = enum.auto()
  21. class SchedulerOutputs:
  22. def __init__(
  23. self,
  24. scheduled_seq_groups: List[SequenceGroup],
  25. prompt_run: bool,
  26. num_batched_tokens: int,
  27. blocks_to_swap_in: Dict[int, int],
  28. blocks_to_swap_out: Dict[int, int],
  29. blocks_to_copy: Dict[int, List[int]],
  30. ignored_seq_groups: List[SequenceGroup],
  31. ) -> None:
  32. self.scheduled_seq_groups = scheduled_seq_groups
  33. self.prompt_run = prompt_run
  34. self.num_batched_tokens = num_batched_tokens
  35. self.blocks_to_swap_in = blocks_to_swap_in
  36. self.blocks_to_swap_out = blocks_to_swap_out
  37. self.blocks_to_copy = blocks_to_copy
  38. # Swap in and swap out should never happen at the same time.
  39. assert not (blocks_to_swap_in and blocks_to_swap_out)
  40. self.ignored_seq_groups = ignored_seq_groups
  41. def is_empty(self) -> bool:
  42. # NOTE: We do not consider the ignored sequence groups.
  43. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  44. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  45. class Scheduler:
  46. def __init__(
  47. self,
  48. scheduler_config: SchedulerConfig,
  49. cache_config: CacheConfig,
  50. ) -> None:
  51. self.scheduler_config = scheduler_config
  52. self.cache_config = cache_config
  53. self.prompt_limit = min(self.scheduler_config.max_model_len,
  54. self.scheduler_config.max_num_batched_tokens)
  55. # Instantiate the scheduling policy.
  56. self.policy = PolicyFactory.get_policy(policy_name="fcfs")
  57. # Create the block space manager.
  58. self.block_manager = BlockSpaceManager(
  59. block_size=self.cache_config.block_size,
  60. num_gpu_blocks=self.cache_config.num_gpu_blocks,
  61. num_cpu_blocks=self.cache_config.num_cpu_blocks,
  62. sliding_window=self.cache_config.sliding_window)
  63. # TODO: Use deque instead of list for better performance.
  64. # Sequence groups in the WAITING state.
  65. self.waiting: List[SequenceGroup] = []
  66. # Sequence groups in the RUNNING state.
  67. self.running: List[SequenceGroup] = []
  68. # Sequence groups in the SWAPPED state.
  69. self.swapped: List[SequenceGroup] = []
  70. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  71. # Add sequence groups to the waiting queue.
  72. self.waiting.append(seq_group)
  73. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  74. if isinstance(request_id, str):
  75. request_id = (request_id, )
  76. request_ids = set(request_id)
  77. for state_queue in [self.waiting, self.running, self.swapped]:
  78. # We need to reverse the list as we are removing elements
  79. # from it as we iterate over it. If we don't do it,
  80. # indices will get messed up and we will skip over elements.
  81. for seq_group in reversed(state_queue):
  82. if seq_group.request_id in request_ids:
  83. # Remove the sequence group from the state queue.
  84. state_queue.remove(seq_group)
  85. for seq in seq_group.get_seqs():
  86. if seq.is_finished():
  87. continue
  88. seq.status = SequenceStatus.FINISHED_ABORTED
  89. self.free_seq(seq)
  90. request_ids.remove(seq_group.request_id)
  91. if not request_ids:
  92. return
  93. def has_unfinished_seqs(self) -> bool:
  94. return self.waiting or self.running or self.swapped
  95. def get_num_unfinished_seq_groups(self) -> int:
  96. return len(self.waiting) + len(self.running) + len(self.swapped)
  97. def _schedule(self) -> SchedulerOutputs:
  98. # Blocks that need to be swaped or copied before model execution.
  99. blocks_to_swap_in: Dict[int, int] = {}
  100. blocks_to_swap_out: Dict[int, int] = {}
  101. blocks_to_copy: Dict[int, List[int]] = {}
  102. # Fix the current time.
  103. now = time.monotonic()
  104. # Join waiting sequences if possible.
  105. if not self.swapped:
  106. ignored_seq_groups: List[SequenceGroup] = []
  107. scheduled: List[SequenceGroup] = []
  108. # The total number of sequences on the fly, including the
  109. # requests in the generation phase.
  110. num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
  111. for seq_group in self.running)
  112. seq_lens: List[int] = []
  113. # Optimization: We do not sort the waiting queue since the preempted
  114. # sequence groups are added to the front and the new sequence groups
  115. # are added to the back.
  116. while self.waiting:
  117. seq_group = self.waiting[0]
  118. assert seq_group.num_seqs() == 1, (
  119. "Waiting sequence group should have only one prompt "
  120. "sequence.")
  121. num_prompt_tokens = seq_group.get_seqs()[0].get_len()
  122. if num_prompt_tokens > self.prompt_limit:
  123. logger.warning(
  124. f"Input prompt ({num_prompt_tokens} tokens) is too long"
  125. f" and exceeds limit of {self.prompt_limit}")
  126. for seq in seq_group.get_seqs():
  127. seq.status = SequenceStatus.FINISHED_IGNORED
  128. ignored_seq_groups.append(seq_group)
  129. self.waiting.pop(0)
  130. continue
  131. # If the sequence group cannot be allocated, stop.
  132. if not self.block_manager.can_allocate(seq_group):
  133. break
  134. # If the number of batched tokens exceeds the limit, stop.
  135. new_seq_lens = seq_lens + [num_prompt_tokens]
  136. num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
  137. if (num_batched_tokens >
  138. self.scheduler_config.max_num_batched_tokens):
  139. break
  140. # The total number of sequences in the RUNNING state should not
  141. # exceed the maximum number of sequences.
  142. num_new_seqs = seq_group.get_max_num_running_seqs()
  143. if (num_curr_seqs + num_new_seqs >
  144. self.scheduler_config.max_num_seqs):
  145. break
  146. num_paddings = num_batched_tokens - sum(new_seq_lens)
  147. if num_paddings > self.scheduler_config.max_paddings:
  148. break
  149. seq_lens = new_seq_lens
  150. seq_group = self.waiting.pop(0)
  151. self._allocate(seq_group)
  152. self.running.append(seq_group)
  153. num_curr_seqs += num_new_seqs
  154. scheduled.append(seq_group)
  155. if scheduled or ignored_seq_groups:
  156. scheduler_outputs = SchedulerOutputs(
  157. scheduled_seq_groups=scheduled,
  158. prompt_run=True,
  159. num_batched_tokens=len(seq_lens) * max(seq_lens),
  160. blocks_to_swap_in=blocks_to_swap_in,
  161. blocks_to_swap_out=blocks_to_swap_out,
  162. blocks_to_copy=blocks_to_copy,
  163. ignored_seq_groups=ignored_seq_groups,
  164. )
  165. return scheduler_outputs
  166. # NOTE: Preemption happens only when there is no available slot
  167. # to keep all the sequence groups in the RUNNING state.
  168. # In this case, the policy is responsible for deciding which sequence
  169. # groups to preempt.
  170. self.running = self.policy.sort_by_priority(now, self.running)
  171. # Reserve new token slots for the running sequence groups.
  172. running: List[SequenceGroup] = []
  173. preempted: List[SequenceGroup] = []
  174. while self.running:
  175. seq_group = self.running.pop(0)
  176. while not self.block_manager.can_append_slot(seq_group):
  177. if self.running:
  178. # Preempt the lowest-priority sequence groups.
  179. victim_seq_group = self.running.pop(-1)
  180. self._preempt(victim_seq_group, blocks_to_swap_out)
  181. preempted.append(victim_seq_group)
  182. else:
  183. # No other sequence groups can be preempted.
  184. # Preempt the current sequence group.
  185. self._preempt(seq_group, blocks_to_swap_out)
  186. preempted.append(seq_group)
  187. break
  188. else:
  189. # Append new slots to the sequence group.
  190. self._append_slot(seq_group, blocks_to_copy)
  191. running.append(seq_group)
  192. self.running = running
  193. # Swap in the sequence groups in the SWAPPED state if possible.
  194. self.swapped = self.policy.sort_by_priority(now, self.swapped)
  195. if not preempted:
  196. num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
  197. for seq_group in self.running)
  198. while self.swapped:
  199. seq_group = self.swapped[0]
  200. # If the sequence group cannot be swapped in, stop.
  201. if not self.block_manager.can_swap_in(seq_group):
  202. break
  203. # The total number of sequences in the RUNNING state should not
  204. # exceed the maximum number of sequences.
  205. num_new_seqs = seq_group.get_max_num_running_seqs()
  206. if (num_curr_seqs + num_new_seqs >
  207. self.scheduler_config.max_num_seqs):
  208. break
  209. seq_group = self.swapped.pop(0)
  210. self._swap_in(seq_group, blocks_to_swap_in)
  211. self._append_slot(seq_group, blocks_to_copy)
  212. num_curr_seqs += num_new_seqs
  213. self.running.append(seq_group)
  214. # Each sequence in the generation phase only takes one token slot.
  215. # Therefore, the number of batched tokens is equal to the number of
  216. # sequences in the RUNNING state.
  217. num_batched_tokens = sum(
  218. seq_group.num_seqs(status=SequenceStatus.RUNNING)
  219. for seq_group in self.running)
  220. scheduler_outputs = SchedulerOutputs(
  221. scheduled_seq_groups=self.running,
  222. prompt_run=False,
  223. num_batched_tokens=num_batched_tokens,
  224. blocks_to_swap_in=blocks_to_swap_in,
  225. blocks_to_swap_out=blocks_to_swap_out,
  226. blocks_to_copy=blocks_to_copy,
  227. ignored_seq_groups=[],
  228. )
  229. return scheduler_outputs
  230. def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
  231. # Schedule sequence groups.
  232. # This function call changes the internal states of the scheduler
  233. # such as self.running, self.swapped, and self.waiting.
  234. scheduler_outputs = self._schedule()
  235. # Create input data structures.
  236. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  237. for seq_group in scheduler_outputs.scheduled_seq_groups:
  238. seq_data: Dict[int, SequenceData] = {}
  239. block_tables: Dict[int, List[int]] = {}
  240. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  241. seq_id = seq.seq_id
  242. seq_data[seq_id] = seq.data
  243. block_tables[seq_id] = self.block_manager.get_block_table(seq)
  244. seq_group_metadata = SequenceGroupMetadata(
  245. request_id=seq_group.request_id,
  246. is_prompt=scheduler_outputs.prompt_run,
  247. seq_data=seq_data,
  248. sampling_params=seq_group.sampling_params,
  249. block_tables=block_tables,
  250. )
  251. seq_group_metadata_list.append(seq_group_metadata)
  252. return seq_group_metadata_list, scheduler_outputs
  253. def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  254. self.block_manager.fork(parent_seq, child_seq)
  255. def free_seq(self, seq: Sequence) -> None:
  256. self.block_manager.free(seq)
  257. def free_finished_seq_groups(self) -> None:
  258. self.running = [
  259. seq_group for seq_group in self.running
  260. if not seq_group.is_finished()
  261. ]
  262. def _allocate(self, seq_group: SequenceGroup) -> None:
  263. self.block_manager.allocate(seq_group)
  264. for seq in seq_group.get_seqs():
  265. seq.status = SequenceStatus.RUNNING
  266. def _append_slot(
  267. self,
  268. seq_group: SequenceGroup,
  269. blocks_to_copy: Dict[int, List[int]],
  270. ) -> None:
  271. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  272. ret = self.block_manager.append_slot(seq)
  273. if ret is not None:
  274. src_block, dst_block = ret
  275. if src_block in blocks_to_copy:
  276. blocks_to_copy[src_block].append(dst_block)
  277. else:
  278. blocks_to_copy[src_block] = [dst_block]
  279. def _preempt(
  280. self,
  281. seq_group: SequenceGroup,
  282. blocks_to_swap_out: Dict[int, int],
  283. preemption_mode: Optional[PreemptionMode] = None,
  284. ) -> None:
  285. # If preemption mode is not specified, we determine the mode as follows:
  286. # We use recomputation by default since it incurs lower overhead than
  287. # swapping. However, when the sequence group has multiple sequences
  288. # (e.g., beam search), recomputation is not currently supported. In
  289. # such a case, we use swapping instead.
  290. # FIXME: This makes our scheduling policy a bit bizarre.
  291. # As swapped sequences are prioritized over waiting sequences,
  292. # sequence groups with multiple sequences are implicitly prioritized
  293. # over sequence groups with a single sequence.
  294. # TODO: Support recomputation for sequence groups with multiple
  295. # sequences. This may require a more sophisticated CUDA kernel.
  296. if preemption_mode is None:
  297. if seq_group.get_max_num_running_seqs() == 1:
  298. preemption_mode = PreemptionMode.RECOMPUTE
  299. else:
  300. preemption_mode = PreemptionMode.SWAP
  301. if preemption_mode == PreemptionMode.RECOMPUTE:
  302. self._preempt_by_recompute(seq_group)
  303. elif preemption_mode == PreemptionMode.SWAP:
  304. self._preempt_by_swap(seq_group, blocks_to_swap_out)
  305. else:
  306. assert False, "Invalid preemption mode."
  307. def _preempt_by_recompute(
  308. self,
  309. seq_group: SequenceGroup,
  310. ) -> None:
  311. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  312. assert len(seqs) == 1
  313. for seq in seqs:
  314. seq.status = SequenceStatus.WAITING
  315. self.block_manager.free(seq)
  316. # NOTE: For FCFS, we insert the preempted sequence group to the front
  317. # of the waiting queue.
  318. self.waiting.insert(0, seq_group)
  319. def _preempt_by_swap(
  320. self,
  321. seq_group: SequenceGroup,
  322. blocks_to_swap_out: Dict[int, int],
  323. ) -> None:
  324. self._swap_out(seq_group, blocks_to_swap_out)
  325. self.swapped.append(seq_group)
  326. def _swap_in(
  327. self,
  328. seq_group: SequenceGroup,
  329. blocks_to_swap_in: Dict[int, int],
  330. ) -> None:
  331. mapping = self.block_manager.swap_in(seq_group)
  332. blocks_to_swap_in.update(mapping)
  333. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  334. seq.status = SequenceStatus.RUNNING
  335. def _swap_out(
  336. self,
  337. seq_group: SequenceGroup,
  338. blocks_to_swap_out: Dict[int, int],
  339. ) -> None:
  340. if not self.block_manager.can_swap_out(seq_group):
  341. # FIXME: Abort the sequence group instead of aborting the
  342. # entire engine.
  343. raise RuntimeError(
  344. "Aborted due to the lack of CPU swap space. Please increase "
  345. "the swap space to avoid this error.")
  346. mapping = self.block_manager.swap_out(seq_group)
  347. blocks_to_swap_out.update(mapping)
  348. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  349. seq.status = SequenceStatus.SWAPPED