1
0

scheduler.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. from collections import deque
  2. import enum
  3. import time
  4. from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set
  5. from loguru import logger
  6. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  7. from aphrodite.processing.block_manager import AllocStatus, BlockSpaceManager
  8. from aphrodite.processing.policy import PolicyFactory
  9. from aphrodite.lora.request import LoRARequest
  10. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  11. SequenceGroupMetadata, SequenceStatus)
  12. class PreemptionMode(enum.Enum):
  13. """Preemption modes.
  14. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  15. and swap them back in when the sequences are resumed.
  16. 2. Recomputation: Discard the blocks of the preempted sequences and
  17. recompute them when the sequences are resumed, treating the sequences as
  18. new prompts.
  19. """
  20. SWAP = enum.auto()
  21. RECOMPUTE = enum.auto()
  22. class SchedulerOutputs:
  23. def __init__(
  24. self,
  25. scheduled_seq_groups: Iterable[SequenceGroup],
  26. prompt_run: bool,
  27. num_batched_tokens: int,
  28. blocks_to_swap_in: Dict[int, int],
  29. blocks_to_swap_out: Dict[int, int],
  30. blocks_to_copy: Dict[int, List[int]],
  31. ignored_seq_groups: List[SequenceGroup],
  32. ) -> None:
  33. self.scheduled_seq_groups = scheduled_seq_groups
  34. self.prompt_run = prompt_run
  35. self.num_batched_tokens = num_batched_tokens
  36. self.blocks_to_swap_in = blocks_to_swap_in
  37. self.blocks_to_swap_out = blocks_to_swap_out
  38. self.blocks_to_copy = blocks_to_copy
  39. # Swap in and swap out should never happen at the same time.
  40. assert not (blocks_to_swap_in and blocks_to_swap_out)
  41. self.ignored_seq_groups = ignored_seq_groups
  42. self.num_loras = len(self.lora_requests)
  43. if self.num_loras > 0:
  44. self._sort_by_lora_ids()
  45. def is_empty(self) -> bool:
  46. # NOTE: We do not consider the ignored sequence groups.
  47. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  48. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  49. def _sort_by_lora_ids(self) -> bool:
  50. self.scheduled_seq_groups = sorted(
  51. self.scheduled_seq_groups,
  52. key=lambda g: (g.lora_request.lora_int_id
  53. if g.lora_request else 0, g.request_id))
  54. @property
  55. def lora_requests(self) -> Set[LoRARequest]:
  56. return {g.lora_request for g in self.scheduled_seq_groups}
  57. class Scheduler:
  58. def __init__(
  59. self,
  60. scheduler_config: SchedulerConfig,
  61. cache_config: CacheConfig,
  62. lora_config: Optional[LoRAConfig],
  63. ) -> None:
  64. self.scheduler_config = scheduler_config
  65. self.cache_config = cache_config
  66. # NOTE for LoRA scheduling: the current policy is extremely
  67. # simple and NOT fair. It can lead to starvation of some
  68. # LoRAs. This should be improved in the future.
  69. self.lora_config = lora_config
  70. self.prompt_limit = min(self.scheduler_config.max_model_len,
  71. self.scheduler_config.max_num_batched_tokens)
  72. # Instantiate the scheduling policy.
  73. self.policy = PolicyFactory.get_policy(policy_name="fcfs")
  74. # Create the block space manager.
  75. self.block_manager = BlockSpaceManager(
  76. block_size=self.cache_config.block_size,
  77. num_gpu_blocks=self.cache_config.num_gpu_blocks,
  78. num_cpu_blocks=self.cache_config.num_cpu_blocks,
  79. sliding_window=self.cache_config.sliding_window,
  80. enable_caching=self.cache_config.context_shift)
  81. # Sequence groups in the WAITING state.
  82. self.waiting: Deque[SequenceGroup] = deque()
  83. # Sequence groups in the RUNNING state.
  84. self.running: Deque[SequenceGroup] = deque()
  85. # Sequence groups in the SWAPPED state.
  86. self.swapped: Deque[SequenceGroup] = deque()
  87. @property
  88. def lora_enabled(self) -> bool:
  89. return bool(self.lora_config)
  90. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  91. # Add sequence groups to the waiting queue.
  92. self.waiting.append(seq_group)
  93. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  94. """Aborts a sequence group with the given ID.
  95. Check if the sequence group with the given ID
  96. is present in any of the state queue.
  97. If present, remove the sequence group from the state queue.
  98. Also, if any of the sequences in the sequence group is not finished,
  99. free the sequence with status `FINISHED_ABORTED`.
  100. Otherwise, do nothing.
  101. Args:
  102. request_id: The ID(s) of the sequence group to abort.
  103. """
  104. if isinstance(request_id, str):
  105. request_id = (request_id, )
  106. request_ids = set(request_id)
  107. for state_queue in [self.waiting, self.running, self.swapped]:
  108. aborted_groups: List[SequenceGroup] = []
  109. for seq_group in state_queue:
  110. if not request_ids:
  111. # Using 'break' here may add two extra iterations,
  112. # but is acceptable to reduce complexity .
  113. break
  114. if seq_group.request_id in request_ids:
  115. # Appending aborted group into pending list.
  116. aborted_groups.append(seq_group)
  117. request_ids.remove(seq_group.request_id)
  118. for aborted_group in aborted_groups:
  119. # Remove the sequence group from the state queue.
  120. state_queue.remove(aborted_group)
  121. for seq in aborted_group.get_seqs():
  122. if seq.is_finished():
  123. continue
  124. seq.status = SequenceStatus.FINISHED_ABORTED
  125. self.free_seq(seq)
  126. def has_unfinished_seqs(self) -> bool:
  127. return self.waiting or self.running or self.swapped
  128. def get_num_unfinished_seq_groups(self) -> int:
  129. return len(self.waiting) + len(self.running) + len(self.swapped)
  130. def _schedule(self) -> SchedulerOutputs:
  131. # Blocks that need to be swapped or copied before model execution.
  132. blocks_to_swap_in: Dict[int, int] = {}
  133. blocks_to_swap_out: Dict[int, int] = {}
  134. blocks_to_copy: Dict[int, List[int]] = {}
  135. # Fix the current time.
  136. now = time.monotonic()
  137. # Join waiting sequences if possible.
  138. if not self.swapped:
  139. ignored_seq_groups: List[SequenceGroup] = []
  140. scheduled: List[SequenceGroup] = []
  141. # The total number of sequences on the fly, including the
  142. # requests in the generation phase.
  143. num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
  144. for seq_group in self.running)
  145. curr_loras = set(
  146. seq_group.lora_int_id
  147. for seq_group in self.running) if self.lora_enabled else None
  148. seq_lens: List[int] = []
  149. # Optimization: We do not sort the waiting queue since the preempted
  150. # sequence groups are added to the front and the new sequence groups
  151. # are added to the back.
  152. leftover_waiting_sequences = deque()
  153. while self.waiting:
  154. seq_group = self.waiting[0]
  155. waiting_seqs = seq_group.get_seqs(
  156. status=SequenceStatus.WAITING)
  157. assert len(waiting_seqs) == 1, (
  158. "Waiting sequence group should have only one prompt "
  159. "sequence.")
  160. num_prompt_tokens = waiting_seqs[0].get_len()
  161. if num_prompt_tokens > self.prompt_limit:
  162. logger.warning(
  163. f"Input prompt ({num_prompt_tokens} tokens) is too long"
  164. f" and exceeds limit of {self.prompt_limit}")
  165. for seq in waiting_seqs:
  166. seq.status = SequenceStatus.FINISHED_IGNORED
  167. ignored_seq_groups.append(seq_group)
  168. self.waiting.popleft()
  169. continue
  170. # If the sequence group cannot be allocated, stop.
  171. can_allocate = self.block_manager.can_allocate(seq_group)
  172. if can_allocate == AllocStatus.LATER:
  173. break
  174. elif can_allocate == AllocStatus.NEVER:
  175. logger.warning(
  176. f"Input prompt ({num_prompt_tokens} tokens) is too long"
  177. f" and exceeds the capacity of block_manager")
  178. for seq in waiting_seqs:
  179. seq.status = SequenceStatus.FINISHED_IGNORED
  180. ignored_seq_groups.append(seq_group)
  181. self.waiting.popleft()
  182. continue
  183. lora_int_id = 0
  184. if self.lora_enabled:
  185. lora_int_id = seq_group.lora_int_id
  186. if lora_int_id > 0 and lora_int_id not in curr_loras and len(
  187. curr_loras) >= self.lora_config.max_loras:
  188. # We don't have a space for another LoRA, so
  189. # we ignore this request for now.
  190. leftover_waiting_sequences.appendleft(seq_group)
  191. self.waiting.popleft()
  192. continue
  193. # If the number of batched tokens exceeds the limit, stop.
  194. new_seq_lens = seq_lens + [num_prompt_tokens]
  195. num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
  196. if (num_batched_tokens >
  197. self.scheduler_config.max_num_batched_tokens):
  198. break
  199. # The total number of sequences in the RUNNING state should not
  200. # exceed the maximum number of sequences.
  201. num_new_seqs = seq_group.get_max_num_running_seqs()
  202. if (num_curr_seqs + num_new_seqs >
  203. self.scheduler_config.max_num_seqs):
  204. break
  205. num_paddings = num_batched_tokens - sum(new_seq_lens)
  206. if num_paddings > self.scheduler_config.max_paddings:
  207. break
  208. seq_lens = new_seq_lens
  209. if lora_int_id > 0:
  210. curr_loras.add(lora_int_id)
  211. self.waiting.popleft()
  212. self._allocate(seq_group)
  213. self.running.append(seq_group)
  214. num_curr_seqs += num_new_seqs
  215. scheduled.append(seq_group)
  216. self.waiting.extendleft(leftover_waiting_sequences)
  217. if scheduled or ignored_seq_groups:
  218. scheduler_outputs = SchedulerOutputs(
  219. scheduled_seq_groups=scheduled,
  220. prompt_run=True,
  221. num_batched_tokens=len(seq_lens) *
  222. max(seq_lens) if seq_lens else 0,
  223. blocks_to_swap_in=blocks_to_swap_in,
  224. blocks_to_swap_out=blocks_to_swap_out,
  225. blocks_to_copy=blocks_to_copy,
  226. ignored_seq_groups=ignored_seq_groups,
  227. )
  228. return scheduler_outputs
  229. # NOTE: Preemption happens only when there is no available slot
  230. # to keep all the sequence groups in the RUNNING state.
  231. # In this case, the policy is responsible for deciding which sequence
  232. # groups to preempt.
  233. self.running = self.policy.sort_by_priority(now, self.running)
  234. # Reserve new token slots for the running sequence groups.
  235. running: Deque[SequenceGroup] = deque()
  236. preempted: List[SequenceGroup] = []
  237. while self.running:
  238. seq_group = self.running.popleft()
  239. while not self.block_manager.can_append_slot(seq_group):
  240. if self.running:
  241. # Preempt the lowest-priority sequence groups.
  242. victim_seq_group = self.running.pop()
  243. self._preempt(victim_seq_group, blocks_to_swap_out)
  244. preempted.append(victim_seq_group)
  245. else:
  246. # No other sequence groups can be preempted.
  247. # Preempt the current sequence group.
  248. self._preempt(seq_group, blocks_to_swap_out)
  249. preempted.append(seq_group)
  250. break
  251. else:
  252. # Append new slots to the sequence group.
  253. self._append_slot(seq_group, blocks_to_copy)
  254. running.append(seq_group)
  255. self.running = running
  256. # Swap in the sequence groups in the SWAPPED state if possible.
  257. self.swapped = self.policy.sort_by_priority(now, self.swapped)
  258. if not preempted:
  259. num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
  260. for seq_group in self.running)
  261. curr_loras = set(
  262. seq_group.lora_int_id
  263. for seq_group in self.running) if self.lora_enabled else None
  264. leftover_swapped = deque()
  265. while self.swapped:
  266. seq_group = self.swapped[0]
  267. lora_int_id = 0
  268. if self.lora_enabled:
  269. lora_int_id = seq_group.lora_int_id
  270. if lora_int_id > 0 and lora_int_id not in curr_loras and len(
  271. curr_loras) >= self.lora_config.max_loras:
  272. # We don't have a space for another LoRA, so
  273. # we ignore this request for now.
  274. leftover_swapped.appendleft(seq_group)
  275. self.swapped.popleft()
  276. continue
  277. # If the sequence group cannot be swapped in, stop.
  278. if not self.block_manager.can_swap_in(seq_group):
  279. break
  280. # The total number of sequences in the RUNNING state should not
  281. # exceed the maximum number of sequences.
  282. num_new_seqs = seq_group.get_max_num_running_seqs()
  283. if (num_curr_seqs + num_new_seqs >
  284. self.scheduler_config.max_num_seqs):
  285. break
  286. if lora_int_id > 0:
  287. curr_loras.add(lora_int_id)
  288. self.swapped.popleft()
  289. self._swap_in(seq_group, blocks_to_swap_in)
  290. self._append_slot(seq_group, blocks_to_copy)
  291. num_curr_seqs += num_new_seqs
  292. self.running.append(seq_group)
  293. self.swapped.extendleft(leftover_swapped)
  294. # Each sequence in the generation phase only takes one token slot.
  295. # Therefore, the number of batched tokens is equal to the number of
  296. # sequences in the RUNNING state.
  297. num_batched_tokens = sum(
  298. seq_group.num_seqs(status=SequenceStatus.RUNNING)
  299. for seq_group in self.running)
  300. scheduler_outputs = SchedulerOutputs(
  301. scheduled_seq_groups=self.running,
  302. prompt_run=False,
  303. num_batched_tokens=num_batched_tokens,
  304. blocks_to_swap_in=blocks_to_swap_in,
  305. blocks_to_swap_out=blocks_to_swap_out,
  306. blocks_to_copy=blocks_to_copy,
  307. ignored_seq_groups=[],
  308. )
  309. return scheduler_outputs
  310. def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
  311. # Schedule sequence groups.
  312. # This function call changes the internal states of the scheduler
  313. # such as self.running, self.swapped, and self.waiting.
  314. scheduler_outputs = self._schedule()
  315. now = time.time()
  316. # Create input data structures.
  317. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  318. for seq_group in scheduler_outputs.scheduled_seq_groups:
  319. seq_group.maybe_set_first_scheduled_time(now)
  320. seq_data: Dict[int, SequenceData] = {}
  321. block_tables: Dict[int, List[int]] = {}
  322. persistent_data: Dict[int, dict] = {}
  323. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  324. seq_id = seq.seq_id
  325. seq_data[seq_id] = seq.data
  326. block_tables[seq_id] = self.block_manager.get_block_table(seq)
  327. persistent_data[seq_id] = seq.persistent_data
  328. self.block_manager.access_all_blocks_in_seq(seq, now)
  329. seq_group_metadata = SequenceGroupMetadata(
  330. request_id=seq_group.request_id,
  331. is_prompt=scheduler_outputs.prompt_run,
  332. seq_data=seq_data,
  333. sampling_params=seq_group.sampling_params,
  334. block_tables=block_tables,
  335. lora_request=seq_group.lora_request,
  336. persistent_data=persistent_data,
  337. computed_block_nums=self.block_manager.
  338. get_common_computed_block_ids(seq_group),
  339. state=seq_group.state,
  340. )
  341. seq_group_metadata_list.append(seq_group_metadata)
  342. return seq_group_metadata_list, scheduler_outputs
  343. def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  344. self.block_manager.fork(parent_seq, child_seq)
  345. def free_seq(self, seq: Sequence) -> None:
  346. self.block_manager.free(seq)
  347. def free_finished_seq_groups(self) -> None:
  348. self.running = deque(seq_group for seq_group in self.running
  349. if not seq_group.is_finished())
  350. def _allocate(self, seq_group: SequenceGroup) -> None:
  351. self.block_manager.allocate(seq_group)
  352. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
  353. seq.status = SequenceStatus.RUNNING
  354. def _append_slot(
  355. self,
  356. seq_group: SequenceGroup,
  357. blocks_to_copy: Dict[int, List[int]],
  358. ) -> None:
  359. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  360. ret = self.block_manager.append_slot(seq)
  361. if ret is not None:
  362. src_block, dst_block = ret
  363. if src_block in blocks_to_copy:
  364. blocks_to_copy[src_block].append(dst_block)
  365. else:
  366. blocks_to_copy[src_block] = [dst_block]
  367. def _preempt(
  368. self,
  369. seq_group: SequenceGroup,
  370. blocks_to_swap_out: Dict[int, int],
  371. preemption_mode: Optional[PreemptionMode] = None,
  372. ) -> None:
  373. # If preemption mode is not specified, we determine the mode as follows:
  374. # We use recomputation by default since it incurs lower overhead than
  375. # swapping. However, when the sequence group has multiple sequences
  376. # (e.g., beam search), recomputation is not currently supported. In
  377. # such a case, we use swapping instead.
  378. # FIXME: This makes our scheduling policy a bit bizarre.
  379. # As swapped sequences are prioritized over waiting sequences,
  380. # sequence groups with multiple sequences are implicitly prioritized
  381. # over sequence groups with a single sequence.
  382. # TODO: Support recomputation for sequence groups with multiple
  383. # sequences. This may require a more sophisticated CUDA kernel.
  384. if preemption_mode is None:
  385. if seq_group.get_max_num_running_seqs() == 1:
  386. preemption_mode = PreemptionMode.RECOMPUTE
  387. else:
  388. preemption_mode = PreemptionMode.SWAP
  389. if preemption_mode == PreemptionMode.RECOMPUTE:
  390. self._preempt_by_recompute(seq_group)
  391. elif preemption_mode == PreemptionMode.SWAP:
  392. self._preempt_by_swap(seq_group, blocks_to_swap_out)
  393. else:
  394. raise AssertionError("Invalid preemption mode.")
  395. def _preempt_by_recompute(
  396. self,
  397. seq_group: SequenceGroup,
  398. ) -> None:
  399. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  400. assert len(seqs) == 1
  401. for seq in seqs:
  402. seq.status = SequenceStatus.WAITING
  403. self.block_manager.free(seq)
  404. # NOTE: For FCFS, we insert the preempted sequence group to the front
  405. # of the waiting queue.
  406. self.waiting.appendleft(seq_group)
  407. def _preempt_by_swap(
  408. self,
  409. seq_group: SequenceGroup,
  410. blocks_to_swap_out: Dict[int, int],
  411. ) -> None:
  412. self._swap_out(seq_group, blocks_to_swap_out)
  413. self.swapped.append(seq_group)
  414. def _swap_in(
  415. self,
  416. seq_group: SequenceGroup,
  417. blocks_to_swap_in: Dict[int, int],
  418. ) -> None:
  419. mapping = self.block_manager.swap_in(seq_group)
  420. blocks_to_swap_in.update(mapping)
  421. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  422. seq.status = SequenceStatus.RUNNING
  423. def _swap_out(
  424. self,
  425. seq_group: SequenceGroup,
  426. blocks_to_swap_out: Dict[int, int],
  427. ) -> None:
  428. if not self.block_manager.can_swap_out(seq_group):
  429. # FIXME: Abort the sequence group instead of aborting the
  430. # entire engine.
  431. raise RuntimeError(
  432. "Aborted due to the lack of CPU swap space. Please increase "
  433. "the swap space to avoid this error.")
  434. mapping = self.block_manager.swap_out(seq_group)
  435. blocks_to_swap_out.update(mapping)
  436. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  437. seq.status = SequenceStatus.SWAPPED
  438. def mark_blocks_as_computed(self, seq_group: SequenceGroup):
  439. self.block_manager.mark_blocks_as_computed(seq_group)