scheduler.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. from collections import deque
  2. import enum
  3. import time
  4. from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set
  5. from aphrodite.common.config import LoRAConfig, CacheConfig, SchedulerConfig
  6. from aphrodite.processing.block_manager import AllocStatus, BlockSpaceManager
  7. from aphrodite.processing.policy import PolicyFactory
  8. from aphrodite.lora.request import LoRARequest
  9. from aphrodite.common.logger import init_logger
  10. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  11. SequenceGroupMetadata, SequenceStatus)
  12. from aphrodite.common.prefix import PrefixPool
  13. logger = init_logger(__name__)
class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """
    # Move KV-cache blocks to CPU memory; resume by swapping back in.
    SWAP = enum.auto()
    # Drop KV-cache blocks; resume by re-running the prompt from scratch.
    RECOMPUTE = enum.auto()
  24. class SchedulerOutputs:
  25. def __init__(
  26. self,
  27. scheduled_seq_groups: Iterable[SequenceGroup],
  28. prompt_run: bool,
  29. num_batched_tokens: int,
  30. blocks_to_swap_in: Dict[int, int],
  31. blocks_to_swap_out: Dict[int, int],
  32. blocks_to_copy: Dict[int, List[int]],
  33. ignored_seq_groups: List[SequenceGroup],
  34. ) -> None:
  35. self.scheduled_seq_groups = scheduled_seq_groups
  36. self.prompt_run = prompt_run
  37. self.num_batched_tokens = num_batched_tokens
  38. self.blocks_to_swap_in = blocks_to_swap_in
  39. self.blocks_to_swap_out = blocks_to_swap_out
  40. self.blocks_to_copy = blocks_to_copy
  41. # Swap in and swap out should never happen at the same time.
  42. assert not (blocks_to_swap_in and blocks_to_swap_out)
  43. self.ignored_seq_groups = ignored_seq_groups
  44. self.num_loras = len(self.lora_requests)
  45. if self.num_loras > 0:
  46. self._sort_by_lora_ids()
  47. def is_empty(self) -> bool:
  48. # NOTE: We do not consider the ignored sequence groups.
  49. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  50. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  51. def _sort_by_lora_ids(self) -> bool:
  52. self.scheduled_seq_groups = sorted(
  53. self.scheduled_seq_groups,
  54. key=lambda g: (g.lora_request.lora_int_id
  55. if g.lora_request else 0, g.request_id))
  56. @property
  57. def lora_requests(self) -> Set[LoRARequest]:
  58. return {g.lora_request for g in self.scheduled_seq_groups}
class Scheduler:
    """Decides, each engine step, which sequence groups run on the GPU.

    Tracks sequence groups across three queues -- WAITING, RUNNING and
    SWAPPED -- and drives the block manager to allocate, append, swap and
    free KV-cache blocks. When the cache runs out of slots, lower-priority
    groups are preempted either by swapping to CPU or by recomputation.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # NOTE for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        # Prompts longer than this are ignored outright (FINISHED_IGNORED).
        self.prompt_limit = min(self.scheduler_config.max_model_len,
                                self.scheduler_config.max_num_batched_tokens)

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window)

        # Create the prefix pool to cache the prefixes.
        self.prefix_pool = PrefixPool(self.cache_config.block_size)

        # Sequence groups in the WAITING state.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        self.swapped: Deque[SequenceGroup] = deque()

    @property
    def lora_enabled(self) -> bool:
        """Whether a LoRA config was supplied at construction."""
        return bool(self.lora_config)

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Enqueue a new sequence group for scheduling."""
        # Add sequence groups to the waiting queue.
        self.waiting.append(seq_group)

    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a sequence group with the given ID.

        Check if the sequence group with the given ID
        is present in any of the state queue.
        If present, remove the sequence group from the state queue.
        Also, if any of the sequences in the sequence group is not finished,
        free the sequence with status `FINISHED_ABORTED`.
        Otherwise, do nothing.

        Args:
            request_id: The ID(s) of the sequence group to abort.
        """
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
            aborted_groups: List[SequenceGroup] = []
            for seq_group in state_queue:
                if not request_ids:
                    # Using 'break' here may add two extra iterations,
                    # but is acceptable to reduce complexity .
                    break
                if seq_group.request_id in request_ids:
                    # Appending aborted group into pending list.
                    aborted_groups.append(seq_group)
                    request_ids.remove(seq_group.request_id)
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
                for seq in aborted_group.get_seqs():
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)

    def has_unfinished_seqs(self) -> bool:
        # NOTE(review): returns a (possibly truthy) deque rather than a strict
        # bool despite the annotation; callers use it only as a predicate.
        return self.waiting or self.running or self.swapped

    def get_num_unfinished_seq_groups(self) -> int:
        """Total number of sequence groups across all three queues."""
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _schedule(self) -> SchedulerOutputs:
        """Run one scheduling step, mutating the internal queues.

        When nothing is swapped out, prefers admitting WAITING prompts
        (prompt_run=True). Otherwise reserves a decode slot per RUNNING
        group, preempting lowest-priority groups on cache pressure, and
        swaps SWAPPED groups back in when space allows.
        """
        # Blocks that need to be swaped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.monotonic()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
            # The total number of sequences on the fly, including the
            # requests in the generation phase.
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None
            seq_lens: List[int] = []

            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            leftover_waiting_sequences = deque()
            while self.waiting:
                seq_group = self.waiting[0]

                waiting_seqs = seq_group.get_seqs(
                    status=SequenceStatus.WAITING)
                assert len(waiting_seqs) == 1, (
                    "Waiting sequence group should have only one prompt "
                    "sequence.")
                num_prompt_tokens = waiting_seqs[0].get_len()
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds limit of {self.prompt_limit}")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                # If the sequence group cannot be allocated, stop.
                can_allocate = self.block_manager.can_allocate(seq_group)
                if can_allocate == AllocStatus.LATER:
                    break
                elif can_allocate == AllocStatus.NEVER:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds the capacity of block_manager")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if (lora_int_id > 0 and lora_int_id not in curr_loras
                            and len(curr_loras) >= self.lora_config.max_loras):
                        # We don't have a space for another LoRA, so
                        # we ignore this request for now.
                        leftover_waiting_sequences.appendleft(seq_group)
                        self.waiting.popleft()
                        continue

                # If the number of batched tokens exceeds the limit, stop.
                # NOTE: the prompt batch is padded to the longest prompt,
                # hence len * max rather than a plain sum.
                new_seq_lens = seq_lens + [num_prompt_tokens]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if (num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                # Cap the padding wasted by admitting this prompt.
                num_paddings = num_batched_tokens - sum(new_seq_lens)
                if num_paddings > self.scheduler_config.max_paddings:
                    break
                seq_lens = new_seq_lens

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.waiting.popleft()
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            # Re-queue the LoRA groups that were set aside, preserving order.
            self.waiting.extendleft(leftover_waiting_sequences)

            if scheduled or ignored_seq_groups:
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=len(seq_lens) *
                    max(seq_lens) if seq_lens else 0,
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE: Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: Deque[SequenceGroup] = deque()
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.popleft()
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop()
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # `while/else`: runs only when the inner loop did NOT break,
                # i.e. a slot is available for this group.
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        if not preempted:
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None

            leftover_swapped = deque()

            while self.swapped:
                seq_group = self.swapped[0]
                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if (lora_int_id > 0 and lora_int_id not in curr_loras
                            and len(curr_loras) >= self.lora_config.max_loras):
                        # We don't have a space for another LoRA, so
                        # we ignore this request for now.
                        leftover_swapped.appendleft(seq_group)
                        self.swapped.popleft()
                        continue

                # If the sequence group cannot be swapped in, stop.
                if not self.block_manager.can_swap_in(seq_group):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.swapped.popleft()
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append_slot(seq_group, blocks_to_copy)
                num_curr_seqs += num_new_seqs
                self.running.append(seq_group)

            self.swapped.extendleft(leftover_swapped)

        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens is equal to the number of
        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=self.running,
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build per-group model-runner metadata.

        Returns:
            A pair of (metadata for each scheduled group, the raw
            SchedulerOutputs describing block movements).
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            persistent_data: Dict[int, dict] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                persistent_data[seq_id] = seq.persistent_data
            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                lora_request=seq_group.lora_request,
                persistent_data=persistent_data,
                prefix=seq_group.prefix,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Mirror a sequence fork in the block manager (block sharing)."""
        self.block_manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        """Release the KV-cache blocks held by `seq`."""
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        """Drop finished sequence groups from the RUNNING queue."""
        self.running = deque(seq_group for seq_group in self.running
                             if not seq_group.is_finished())

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate prompt blocks and mark the group's seqs RUNNING."""
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve one new token slot per running sequence in the group.

        When the block manager reports a copy-on-write (src, dst) pair,
        record it in `blocks_to_copy` (mutated in place).
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                src_block, dst_block = ret
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Evict a running group, by recompute (single-seq) or swap."""
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME: This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO: Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Free the group's blocks and send it back to the WAITING queue."""
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.appendleft(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group's blocks out to CPU and move it to SWAPPED."""
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Record CPU->GPU block moves and mark the group's seqs RUNNING."""
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Record GPU->CPU block moves and mark the group's seqs SWAPPED.

        Raises:
            RuntimeError: if the CPU swap space cannot hold the blocks.
        """
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME: Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED