# scheduler.py
  1. import enum
  2. import time
  3. from typing import Dict, Iterable, List, Optional, Tuple, Union
  4. from aphrodite.common.config import CacheConfig, SchedulerConfig
  5. from aphrodite.processing.block_manager import BlockSpaceManager
  6. from aphrodite.processing.policy import PolicyFactory
  7. from aphrodite.common.logger import init_logger
  8. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  9. SequenceGroupMetadata, SequenceStatus)
  10. logger = init_logger(__name__)
  11. class PreemptionMode(enum.Enum):
  12. """Preemption modes.
  13. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  14. and swap them back in when the sequences are resumed.
  15. 2. Recomputation: Discard the blocks of the preempted sequences and
  16. recompute them when the sequences are resumed, treating the sequences as
  17. new prompts.
  18. """
  19. SWAP = enum.auto()
  20. RECOMPUTE = enum.auto()
  21. class SchedulerOutputs:
  22. def __init__(
  23. self,
  24. scheduled_seq_groups: List[SequenceGroup],
  25. prompt_run: bool,
  26. num_batched_tokens: int,
  27. blocks_to_swap_in: Dict[int, int],
  28. blocks_to_swap_out: Dict[int, int],
  29. blocks_to_copy: Dict[int, List[int]],
  30. ignored_seq_groups: List[SequenceGroup],
  31. ) -> None:
  32. self.scheduled_seq_groups = scheduled_seq_groups
  33. self.prompt_run = prompt_run
  34. self.num_batched_tokens = num_batched_tokens
  35. self.blocks_to_swap_in = blocks_to_swap_in
  36. self.blocks_to_swap_out = blocks_to_swap_out
  37. self.blocks_to_copy = blocks_to_copy
  38. # Swap in and swap out should never happen at the same time.
  39. assert not (blocks_to_swap_in and blocks_to_swap_out)
  40. self.ignored_seq_groups = ignored_seq_groups
  41. def is_empty(self) -> bool:
  42. # NOTE: We do not consider the ignored sequence groups.
  43. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  44. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  45. class Scheduler:
  46. def __init__(
  47. self,
  48. scheduler_config: SchedulerConfig,
  49. cache_config: CacheConfig,
  50. ) -> None:
  51. self.scheduler_config = scheduler_config
  52. self.cache_config = cache_config
  53. self.prompt_limit = min(self.scheduler_config.max_model_len,
  54. self.scheduler_config.max_num_batched_tokens)
  55. # Instantiate the scheduling policy.
  56. self.policy = PolicyFactory.get_policy(policy_name="fcfs")
  57. # Create the block space manager.
  58. self.block_manager = BlockSpaceManager(
  59. block_size=self.cache_config.block_size,
  60. num_gpu_blocks=self.cache_config.num_gpu_blocks,
  61. num_cpu_blocks=self.cache_config.num_cpu_blocks,
  62. sliding_window=self.cache_config.sliding_window,
  63. )
  64. # TODO: Use deque instead of list for better performance.
  65. # Sequence groups in the WAITING state.
  66. self.waiting: List[SequenceGroup] = []
  67. # Sequence groups in the RUNNING state.
  68. self.running: List[SequenceGroup] = []
  69. # Sequence groups in the SWAPPED state.
  70. self.swapped: List[SequenceGroup] = []
  71. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  72. # Add sequence groups to the waiting queue.
  73. self.waiting.append(seq_group)
  74. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  75. if isinstance(request_id, str):
  76. request_id = (request_id, )
  77. request_ids = set(request_id)
  78. for state_queue in [self.waiting, self.running, self.swapped]:
  79. # We need to reverse the list as we are removing elements
  80. # from it as we iterate over it. If we don't do it,
  81. # indices will get messed up and we will skip over elements.
  82. for seq_group in reversed(state_queue):
  83. if seq_group.request_id in request_ids:
  84. # Remove the sequence group from the state queue.
  85. state_queue.remove(seq_group)
  86. for seq in seq_group.get_seqs():
  87. if seq.is_finished():
  88. continue
  89. seq.status = SequenceStatus.FINISHED_ABORTED
  90. self.free_seq(seq)
  91. request_ids.remove(seq_group.request_id)
  92. if not request_ids:
  93. return
  94. def has_unfinished_seqs(self) -> bool:
  95. return self.waiting or self.running or self.swapped
  96. def get_num_unfinished_seq_groups(self) -> int:
  97. return len(self.waiting) + len(self.running) + len(self.swapped)
  98. def _schedule(self) -> SchedulerOutputs:
  99. # Blocks that need to be swaped or copied before model execution.
  100. blocks_to_swap_in: Dict[int, int] = {}
  101. blocks_to_swap_out: Dict[int, int] = {}
  102. blocks_to_copy: Dict[int, List[int]] = {}
  103. # Fix the current time.
  104. now = time.time()
  105. # Join waiting sequences if possible.
  106. if not self.swapped:
  107. ignored_seq_groups: List[SequenceGroup] = []
  108. scheduled: List[SequenceGroup] = []
  109. # The total number of sequences on the fly, including the
  110. # requests in the generation phase.
  111. num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
  112. for seq_group in self.running)
  113. num_batched_tokens = 0
  114. # Optimization: We do not sort the waiting queue since the preempted
  115. # sequence groups are added to the front and the new sequence groups
  116. # are added to the back.
  117. while self.waiting:
  118. seq_group = self.waiting[0]
  119. assert seq_group.num_seqs() == 1, (
  120. "Waiting sequence group should have only one prompt "
  121. "sequence.")
  122. num_prompt_tokens = seq_group.get_seqs()[0].get_len()
  123. if num_prompt_tokens > self.prompt_limit:
  124. logger.warning(
  125. f"Input prompt ({num_prompt_tokens} tokens) is too long"
  126. f" and exceeds limit of {self.prompt_limit}")
  127. for seq in seq_group.get_seqs():
  128. seq.status = SequenceStatus.FINISHED_IGNORED
  129. ignored_seq_groups.append(seq_group)
  130. self.waiting.pop(0)
  131. continue
  132. # If the sequence group cannot be allocated, stop.
  133. if not self.block_manager.can_allocate(seq_group):
  134. break
  135. # If the number of batched tokens exceeds the limit, stop.
  136. if (num_batched_tokens + num_prompt_tokens >
  137. self.scheduler_config.max_num_batched_tokens):
  138. break
  139. # The total number of sequences in the RUNNING state should not
  140. # exceed the maximum number of sequences.
  141. num_new_seqs = seq_group.get_max_num_running_seqs()
  142. if (num_curr_seqs + num_new_seqs >
  143. self.scheduler_config.max_num_seqs):
  144. break
  145. seq_group = self.waiting.pop(0)
  146. self._allocate(seq_group)
  147. self.running.append(seq_group)
  148. num_batched_tokens += num_prompt_tokens
  149. num_curr_seqs += num_new_seqs
  150. scheduled.append(seq_group)
  151. if scheduled or ignored_seq_groups:
  152. scheduler_outputs = SchedulerOutputs(
  153. scheduled_seq_groups=scheduled,
  154. prompt_run=True,
  155. num_batched_tokens=num_batched_tokens,
  156. blocks_to_swap_in=blocks_to_swap_in,
  157. blocks_to_swap_out=blocks_to_swap_out,
  158. blocks_to_copy=blocks_to_copy,
  159. ignored_seq_groups=ignored_seq_groups,
  160. )
  161. return scheduler_outputs
  162. # NOTE: Preemption happens only when there is no available slot
  163. # to keep all the sequence groups in the RUNNING state.
  164. # In this case, the policy is responsible for deciding which sequence
  165. # groups to preempt.
  166. self.running = self.policy.sort_by_priority(now, self.running)
  167. # Reserve new token slots for the running sequence groups.
  168. running: List[SequenceGroup] = []
  169. preempted: List[SequenceGroup] = []
  170. while self.running:
  171. seq_group = self.running.pop(0)
  172. while not self.block_manager.can_append_slot(seq_group):
  173. if self.running:
  174. # Preempt the lowest-priority sequence groups.
  175. victim_seq_group = self.running.pop(-1)
  176. self._preempt(victim_seq_group, blocks_to_swap_out)
  177. preempted.append(victim_seq_group)
  178. else:
  179. # No other sequence groups can be preempted.
  180. # Preempt the current sequence group.
  181. self._preempt(seq_group, blocks_to_swap_out)
  182. preempted.append(seq_group)
  183. break
  184. else:
  185. # Append new slots to the sequence group.
  186. self._append_slot(seq_group, blocks_to_copy)
  187. running.append(seq_group)
  188. self.running = running
  189. # Swap in the sequence groups in the SWAPPED state if possible.
  190. self.swapped = self.policy.sort_by_priority(now, self.swapped)
  191. if not preempted:
  192. num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
  193. for seq_group in self.running)
  194. while self.swapped:
  195. seq_group = self.swapped[0]
  196. # If the sequence group cannot be swapped in, stop.
  197. if not self.block_manager.can_swap_in(seq_group):
  198. break
  199. # The total number of sequences in the RUNNING state should not
  200. # exceed the maximum number of sequences.
  201. num_new_seqs = seq_group.get_max_num_running_seqs()
  202. if (num_curr_seqs + num_new_seqs >
  203. self.scheduler_config.max_num_seqs):
  204. break
  205. seq_group = self.swapped.pop(0)
  206. self._swap_in(seq_group, blocks_to_swap_in)
  207. self._append_slot(seq_group, blocks_to_copy)
  208. num_curr_seqs += num_new_seqs
  209. self.running.append(seq_group)
  210. # Each sequence in the generation phase only takes one token slot.
  211. # Therefore, the number of batched tokens is equal to the number of
  212. # sequences in the RUNNING state.
  213. num_batched_tokens = sum(
  214. seq_group.num_seqs(status=SequenceStatus.RUNNING)
  215. for seq_group in self.running)
  216. scheduler_outputs = SchedulerOutputs(
  217. scheduled_seq_groups=self.running,
  218. prompt_run=False,
  219. num_batched_tokens=num_batched_tokens,
  220. blocks_to_swap_in=blocks_to_swap_in,
  221. blocks_to_swap_out=blocks_to_swap_out,
  222. blocks_to_copy=blocks_to_copy,
  223. ignored_seq_groups=[],
  224. )
  225. return scheduler_outputs
  226. def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
  227. # Schedule sequence groups.
  228. # This function call changes the internal states of the scheduler
  229. # such as self.running, self.swapped, and self.waiting.
  230. scheduler_outputs = self._schedule()
  231. # Create input data structures.
  232. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  233. for seq_group in scheduler_outputs.scheduled_seq_groups:
  234. seq_data: Dict[int, List[SequenceData]] = {}
  235. block_tables: Dict[int, List[int]] = {}
  236. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  237. seq_id = seq.seq_id
  238. seq_data[seq_id] = seq.data
  239. block_tables[seq_id] = self.block_manager.get_block_table(seq)
  240. seq_group_metadata = SequenceGroupMetadata(
  241. request_id=seq_group.request_id,
  242. is_prompt=scheduler_outputs.prompt_run,
  243. seq_data=seq_data,
  244. sampling_params=seq_group.sampling_params,
  245. block_tables=block_tables,
  246. )
  247. seq_group_metadata_list.append(seq_group_metadata)
  248. return seq_group_metadata_list, scheduler_outputs
  249. def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  250. self.block_manager.fork(parent_seq, child_seq)
  251. def free_seq(self, seq: Sequence) -> None:
  252. self.block_manager.free(seq)
  253. def free_finished_seq_groups(self) -> None:
  254. self.running = [
  255. seq_group for seq_group in self.running
  256. if not seq_group.is_finished()
  257. ]
  258. def _allocate(self, seq_group: SequenceGroup) -> None:
  259. self.block_manager.allocate(seq_group)
  260. for seq in seq_group.get_seqs():
  261. seq.status = SequenceStatus.RUNNING
  262. def _append_slot(
  263. self,
  264. seq_group: SequenceGroup,
  265. blocks_to_copy: Dict[int, List[int]],
  266. ) -> None:
  267. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  268. ret = self.block_manager.append_slot(seq)
  269. if ret is not None:
  270. src_block, dst_block = ret
  271. if src_block in blocks_to_copy:
  272. blocks_to_copy[src_block].append(dst_block)
  273. else:
  274. blocks_to_copy[src_block] = [dst_block]
  275. def _preempt(
  276. self,
  277. seq_group: SequenceGroup,
  278. blocks_to_swap_out: Dict[int, int],
  279. preemption_mode: Optional[PreemptionMode] = None,
  280. ) -> None:
  281. # If preemption mode is not specified, we determine the mode as follows:
  282. # We use recomputation by default since it incurs lower overhead than
  283. # swapping. However, when the sequence group has multiple sequences
  284. # (e.g., beam search), recomputation is not currently supported. In
  285. # such a case, we use swapping instead.
  286. # FIXME: This makes our scheduling policy a bit bizarre.
  287. # As swapped sequences are prioritized over waiting sequences,
  288. # sequence groups with multiple sequences are implicitly prioritized
  289. # over sequence groups with a single sequence.
  290. # TODO: Support recomputation for sequence groups with multiple
  291. # sequences. This may require a more sophisticated CUDA kernel.
  292. if preemption_mode is None:
  293. if seq_group.get_max_num_running_seqs() == 1:
  294. preemption_mode = PreemptionMode.RECOMPUTE
  295. else:
  296. preemption_mode = PreemptionMode.SWAP
  297. if preemption_mode == PreemptionMode.RECOMPUTE:
  298. self._preempt_by_recompute(seq_group)
  299. elif preemption_mode == PreemptionMode.SWAP:
  300. self._preempt_by_swap(seq_group, blocks_to_swap_out)
  301. else:
  302. assert False, "Invalid preemption mode."
  303. def _preempt_by_recompute(
  304. self,
  305. seq_group: SequenceGroup,
  306. ) -> None:
  307. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  308. assert len(seqs) == 1
  309. for seq in seqs:
  310. seq.status = SequenceStatus.WAITING
  311. self.block_manager.free(seq)
  312. # NOTE: For FCFS, we insert the preempted sequence group to the front
  313. # of the waiting queue.
  314. self.waiting.insert(0, seq_group)
  315. def _preempt_by_swap(
  316. self,
  317. seq_group: SequenceGroup,
  318. blocks_to_swap_out: Dict[int, int],
  319. ) -> None:
  320. self._swap_out(seq_group, blocks_to_swap_out)
  321. self.swapped.append(seq_group)
  322. def _swap_in(
  323. self,
  324. seq_group: SequenceGroup,
  325. blocks_to_swap_in: Dict[int, int],
  326. ) -> None:
  327. mapping = self.block_manager.swap_in(seq_group)
  328. blocks_to_swap_in.update(mapping)
  329. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  330. seq.status = SequenceStatus.RUNNING
  331. def _swap_out(
  332. self,
  333. seq_group: SequenceGroup,
  334. blocks_to_swap_out: Dict[int, int],
  335. ) -> None:
  336. if not self.block_manager.can_swap_out(seq_group):
  337. # FIXME: Abort the sequence group instead of aborting the
  338. # entire engine.
  339. raise RuntimeError(
  340. "Aborted due to the lack of CPU swap space. Please increase "
  341. "the swap space to avoid this error.")
  342. mapping = self.block_manager.swap_out(seq_group)
  343. blocks_to_swap_out.update(mapping)
  344. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  345. seq.status = SequenceStatus.SWAPPED