scheduler.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. from collections import deque
  2. import enum
  3. import time
  4. from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set
  5. from loguru import logger
  6. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  7. from aphrodite.processing.block_manager import AllocStatus, BlockSpaceManager
  8. from aphrodite.processing.policy import PolicyFactory
  9. from aphrodite.lora.request import LoRARequest
  10. from aphrodite.common.sequence import (
  11. Sequence,
  12. SequenceData,
  13. SequenceGroup,
  14. SequenceGroupMetadata,
  15. SequenceStatus,
  16. )
  17. class PreemptionMode(enum.Enum):
  18. """Preemption modes.
  19. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  20. and swap them back in when the sequences are resumed.
  21. 2. Recomputation: Discard the blocks of the preempted sequences and
  22. recompute them when the sequences are resumed, treating the sequences as
  23. new prompts.
  24. """
  25. SWAP = enum.auto()
  26. RECOMPUTE = enum.auto()
  27. class SchedulerOutputs:
  28. def __init__(
  29. self,
  30. scheduled_seq_groups: Iterable[SequenceGroup],
  31. prompt_run: bool,
  32. num_batched_tokens: int,
  33. blocks_to_swap_in: Dict[int, int],
  34. blocks_to_swap_out: Dict[int, int],
  35. blocks_to_copy: Dict[int, List[int]],
  36. ignored_seq_groups: List[SequenceGroup],
  37. ) -> None:
  38. self.scheduled_seq_groups = scheduled_seq_groups
  39. self.prompt_run = prompt_run
  40. self.num_batched_tokens = num_batched_tokens
  41. self.blocks_to_swap_in = blocks_to_swap_in
  42. self.blocks_to_swap_out = blocks_to_swap_out
  43. self.blocks_to_copy = blocks_to_copy
  44. # Swap in and swap out should never happen at the same time.
  45. assert not (blocks_to_swap_in and blocks_to_swap_out)
  46. self.ignored_seq_groups = ignored_seq_groups
  47. self.num_loras = len(self.lora_requests)
  48. if self.num_loras > 0:
  49. self._sort_by_lora_ids()
  50. def is_empty(self) -> bool:
  51. # NOTE: We do not consider the ignored sequence groups.
  52. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  53. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  54. def _sort_by_lora_ids(self) -> bool:
  55. self.scheduled_seq_groups = sorted(
  56. self.scheduled_seq_groups,
  57. key=lambda g: (g.lora_int_id, g.request_id),
  58. )
  59. @property
  60. def lora_requests(self) -> Set[LoRARequest]:
  61. return {g.lora_request for g in self.scheduled_seq_groups}
  62. class Scheduler:
  63. def __init__(
  64. self,
  65. scheduler_config: SchedulerConfig,
  66. cache_config: CacheConfig,
  67. lora_config: Optional[LoRAConfig],
  68. ) -> None:
  69. self.scheduler_config = scheduler_config
  70. self.cache_config = cache_config
  71. # NOTE for LoRA scheduling: the current policy is extremely
  72. # simple and NOT fair. It can lead to starvation of some
  73. # LoRAs. This should be improved in the future.
  74. self.lora_config = lora_config
  75. self.prompt_limit = min(
  76. self.scheduler_config.max_model_len,
  77. self.scheduler_config.max_num_batched_tokens,
  78. )
  79. # Instantiate the scheduling policy.
  80. self.policy = PolicyFactory.get_policy(policy_name="fcfs")
  81. # Create the block space manager.
  82. self.block_manager = BlockSpaceManager(
  83. block_size=self.cache_config.block_size,
  84. num_gpu_blocks=self.cache_config.num_gpu_blocks,
  85. num_cpu_blocks=self.cache_config.num_cpu_blocks,
  86. sliding_window=self.cache_config.sliding_window,
  87. enable_caching=self.cache_config.context_shift,
  88. )
  89. # Sequence groups in the WAITING state.
  90. self.waiting: Deque[SequenceGroup] = deque()
  91. # Sequence groups in the RUNNING state.
  92. self.running: Deque[SequenceGroup] = deque()
  93. # Sequence groups in the SWAPPED state.
  94. self.swapped: Deque[SequenceGroup] = deque()
  95. @property
  96. def lora_enabled(self) -> bool:
  97. return bool(self.lora_config)
  98. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  99. # Add sequence groups to the waiting queue.
  100. self.waiting.append(seq_group)
  101. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  102. """Aborts a sequence group with the given ID.
  103. Check if the sequence group with the given ID
  104. is present in any of the state queue.
  105. If present, remove the sequence group from the state queue.
  106. Also, if any of the sequences in the sequence group is not finished,
  107. free the sequence with status `FINISHED_ABORTED`.
  108. Otherwise, do nothing.
  109. Args:
  110. request_id: The ID(s) of the sequence group to abort.
  111. """
  112. if isinstance(request_id, str):
  113. request_id = (request_id, )
  114. request_ids = set(request_id)
  115. for state_queue in [self.waiting, self.running, self.swapped]:
  116. aborted_groups: List[SequenceGroup] = []
  117. for seq_group in state_queue:
  118. if not request_ids:
  119. # Using 'break' here may add two extra iterations,
  120. # but is acceptable to reduce complexity .
  121. break
  122. if seq_group.request_id in request_ids:
  123. # Appending aborted group into pending list.
  124. aborted_groups.append(seq_group)
  125. request_ids.remove(seq_group.request_id)
  126. for aborted_group in aborted_groups:
  127. # Remove the sequence group from the state queue.
  128. state_queue.remove(aborted_group)
  129. for seq in aborted_group.get_seqs():
  130. if seq.is_finished():
  131. continue
  132. seq.status = SequenceStatus.FINISHED_ABORTED
  133. self.free_seq(seq)
  134. def has_unfinished_seqs(self) -> bool:
  135. return self.waiting or self.running or self.swapped
  136. def get_num_unfinished_seq_groups(self) -> int:
  137. return len(self.waiting) + len(self.running) + len(self.swapped)
    def _schedule(self) -> SchedulerOutputs:
        """Select the sequence groups to execute in the next engine step.

        Prefill takes priority: when no group is swapped out, admit as many
        WAITING prompts as the token/sequence/padding budgets and block
        space allow.  Otherwise run decode: keep RUNNING groups going
        (preempting the lowest-priority ones when block space runs out),
        then swap SWAPPED groups back in if capacity allows.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.monotonic()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
            # The total number of sequences on the fly, including the
            # requests in the generation phase.
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = (set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None)
            seq_lens: List[int] = []

            # Optimization: We do not sort the waiting queue since the
            # preempted sequence groups are added to the front and the new
            # sequence groups are added to the back.
            leftover_waiting_sequences: Deque[SequenceGroup] = deque()
            while self.waiting:
                seq_group = self.waiting[0]
                waiting_seqs = seq_group.get_seqs(
                    status=SequenceStatus.WAITING)
                assert len(waiting_seqs) == 1, (
                    "Waiting sequence group should have only one prompt "
                    "sequence.")
                num_prompt_tokens = waiting_seqs[0].get_len()
                # Prompts longer than the model/batch limit can never run:
                # mark them FINISHED_IGNORED and surface them to the caller.
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds limit of {self.prompt_limit}")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                # If the sequence group cannot be allocated, stop.
                can_allocate = self.block_manager.can_allocate(seq_group)
                if can_allocate == AllocStatus.LATER:
                    break
                elif can_allocate == AllocStatus.NEVER:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds the capacity of block_manager")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if (lora_int_id > 0 and lora_int_id not in curr_loras
                            and len(curr_loras) >= self.lora_config.max_loras):
                        # We don't have a space for another LoRA, so
                        # we ignore this request for now.
                        leftover_waiting_sequences.appendleft(seq_group)
                        self.waiting.popleft()
                        continue

                # If the number of batched tokens exceeds the limit, stop.
                # NOTE: prompts in a batch are padded up to the longest one
                # (see num_paddings below), hence count * max, not sum.
                new_seq_lens = seq_lens + [num_prompt_tokens]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if (num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should
                # not exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                # Cap the wasted padding introduced by batching prompts of
                # different lengths together.
                num_paddings = num_batched_tokens - sum(new_seq_lens)
                if num_paddings > self.scheduler_config.max_paddings:
                    break
                seq_lens = new_seq_lens

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.waiting.popleft()
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)
            # Put the LoRA-deferred groups back at the front, order preserved.
            self.waiting.extendleft(leftover_waiting_sequences)

            if scheduled or ignored_seq_groups:
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=len(seq_lens) *
                    max(seq_lens) if seq_lens else 0,
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE: Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: Deque[SequenceGroup] = deque()
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.popleft()
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop()
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        # If anything was just preempted there is no spare block space, so
        # skip swap-in entirely for this step.
        if not preempted:
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = (set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None)
            leftover_swapped: Deque[SequenceGroup] = deque()

            while self.swapped:
                seq_group = self.swapped[0]
                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if (lora_int_id > 0 and lora_int_id not in curr_loras
                            and len(curr_loras) >= self.lora_config.max_loras):
                        # We don't have a space for another LoRA, so
                        # we ignore this request for now.
                        leftover_swapped.appendleft(seq_group)
                        self.swapped.popleft()
                        continue

                # If the sequence group cannot be swapped in, stop.
                if not self.block_manager.can_swap_in(seq_group):
                    break

                # The total number of sequences in the RUNNING state should
                # not exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.swapped.popleft()
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append_slot(seq_group, blocks_to_copy)
                num_curr_seqs += num_new_seqs
                self.running.append(seq_group)
            self.swapped.extendleft(leftover_swapped)

        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens is equal to the number of
        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=self.running,
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs
    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build the per-group model inputs.

        Returns:
            A tuple of (metadata for every scheduled sequence group, the
            raw ``SchedulerOutputs`` describing the swap/copy operations
            the cache engine must perform before this step).
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()
        now = time.time()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group.maybe_set_first_scheduled_time(now)
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            persistent_data: Dict[int, dict] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                persistent_data[seq_id] = seq.persistent_data
                # Refresh the blocks' last-access time.
                # NOTE(review): presumably consumed by the block manager's
                # cache-eviction policy — confirm against BlockSpaceManager.
                self.block_manager.access_all_blocks_in_seq(seq, now)
            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                lora_request=seq_group.lora_request,
                persistent_data=persistent_data,
                computed_block_nums=self.block_manager.
                get_common_computed_block_ids(seq_group),
                state=seq_group.state,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs
  351. def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  352. self.block_manager.fork(parent_seq, child_seq)
  353. def free_seq(self, seq: Sequence) -> None:
  354. self.block_manager.free(seq)
  355. def free_finished_seq_groups(self) -> None:
  356. self.running = deque(seq_group for seq_group in self.running
  357. if not seq_group.is_finished())
  358. def _allocate(self, seq_group: SequenceGroup) -> None:
  359. self.block_manager.allocate(seq_group)
  360. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
  361. seq.status = SequenceStatus.RUNNING
  362. def _append_slot(
  363. self,
  364. seq_group: SequenceGroup,
  365. blocks_to_copy: Dict[int, List[int]],
  366. ) -> None:
  367. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  368. ret = self.block_manager.append_slot(seq)
  369. if ret is not None:
  370. src_block, dst_block = ret
  371. if src_block in blocks_to_copy:
  372. blocks_to_copy[src_block].append(dst_block)
  373. else:
  374. blocks_to_copy[src_block] = [dst_block]
  375. def _preempt(
  376. self,
  377. seq_group: SequenceGroup,
  378. blocks_to_swap_out: Dict[int, int],
  379. preemption_mode: Optional[PreemptionMode] = None,
  380. ) -> None:
  381. # If preemption mode is not specified, we determine the mode as follows:
  382. # We use recomputation by default since it incurs lower overhead than
  383. # swapping. However, when the sequence group has multiple sequences
  384. # (e.g., beam search), recomputation is not currently supported. In
  385. # such a case, we use swapping instead.
  386. # FIXME: This makes our scheduling policy a bit bizarre.
  387. # As swapped sequences are prioritized over waiting sequences,
  388. # sequence groups with multiple sequences are implicitly prioritized
  389. # over sequence groups with a single sequence.
  390. # TODO: Support recomputation for sequence groups with multiple
  391. # sequences. This may require a more sophisticated CUDA kernel.
  392. if preemption_mode is None:
  393. if seq_group.get_max_num_running_seqs() == 1:
  394. preemption_mode = PreemptionMode.RECOMPUTE
  395. else:
  396. preemption_mode = PreemptionMode.SWAP
  397. if preemption_mode == PreemptionMode.RECOMPUTE:
  398. self._preempt_by_recompute(seq_group)
  399. elif preemption_mode == PreemptionMode.SWAP:
  400. self._preempt_by_swap(seq_group, blocks_to_swap_out)
  401. else:
  402. raise AssertionError("Invalid preemption mode.")
  403. def _preempt_by_recompute(
  404. self,
  405. seq_group: SequenceGroup,
  406. ) -> None:
  407. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  408. assert len(seqs) == 1
  409. for seq in seqs:
  410. seq.status = SequenceStatus.WAITING
  411. self.block_manager.free(seq)
  412. # NOTE: For FCFS, we insert the preempted sequence group to the front
  413. # of the waiting queue.
  414. self.waiting.appendleft(seq_group)
  415. def _preempt_by_swap(
  416. self,
  417. seq_group: SequenceGroup,
  418. blocks_to_swap_out: Dict[int, int],
  419. ) -> None:
  420. self._swap_out(seq_group, blocks_to_swap_out)
  421. self.swapped.append(seq_group)
  422. def _swap_in(
  423. self,
  424. seq_group: SequenceGroup,
  425. blocks_to_swap_in: Dict[int, int],
  426. ) -> None:
  427. mapping = self.block_manager.swap_in(seq_group)
  428. blocks_to_swap_in.update(mapping)
  429. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  430. seq.status = SequenceStatus.RUNNING
  431. def _swap_out(
  432. self,
  433. seq_group: SequenceGroup,
  434. blocks_to_swap_out: Dict[int, int],
  435. ) -> None:
  436. if not self.block_manager.can_swap_out(seq_group):
  437. # FIXME: Abort the sequence group instead of aborting the
  438. # entire engine.
  439. raise RuntimeError(
  440. "Aborted due to the lack of CPU swap space. Please increase "
  441. "the swap space to avoid this error.")
  442. mapping = self.block_manager.swap_out(seq_group)
  443. blocks_to_swap_out.update(mapping)
  444. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  445. seq.status = SequenceStatus.SWAPPED
  446. def mark_blocks_as_computed(self, seq_group: SequenceGroup):
  447. self.block_manager.mark_blocks_as_computed(seq_group)