scheduler.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212
  1. import enum
  2. import os
  3. import random
  4. import time
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
  8. from loguru import logger
  9. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  10. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  11. SequenceGroupMetadata, SequenceStatus)
  12. from aphrodite.lora.request import LoRARequest
  13. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  14. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  # Test-only knobs. Artificial preemption is switched on when the
  # APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT environment variable is set to any
  # non-empty value; decode is then preempted with
  # ARTIFICIAL_PREEMPTION_PROB probability.
  ENABLE_ARTIFICIAL_PREEMPT = bool(
      os.environ.get("APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))  # noqa
  ARTIFICIAL_PREEMPTION_PROB = 0.5
  # Initial value of the scheduler's artificial-preemption counter.
  ARTIFICIAL_PREEMPTION_MAX_CNT = 500
  class PreemptionMode(enum.Enum):
      """How a preempted sequence gives up its blocks.

      SWAP: the blocks of the preempted sequences are swapped out to CPU
          memory and swapped back in when the sequences are resumed.
      RECOMPUTE: the blocks of the preempted sequences are discarded and
          recomputed when the sequences are resumed, treating the sequences
          as new prompts.
      """

      SWAP = enum.auto()
      RECOMPUTE = enum.auto()
  @dataclass
  class SchedulingBudget:
      """Token and sequence capacity available in one scheduling pass.

      The counters are request_id-aware: an `add_*` call for a request id that
      is already counted is ignored, and a `subtract_*` call only takes effect
      for a request id that is currently counted.

      TODO: The request_id-awareness exists because the normal scheduling path
      updates RUNNING num_seqs ahead of time, so the same request could be
      counted more than once when scheduling RUNNING requests. This cannot
      happen with chunked-prefill-only scheduling, so the feature can be
      removed from the API once chunked prefill is enabled by default.
      """

      token_budget: int
      max_num_seqs: int
      # Request ids already included in the respective counter below.
      _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
      _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
      _num_batched_tokens: int = 0
      _num_curr_seqs: int = 0

      def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
          """Return True if both the token and the sequence limits still hold
          after adding the given (non-zero) amounts."""
          assert num_new_tokens != 0
          assert num_new_seqs != 0
          tokens_fit = (self.num_batched_tokens + num_new_tokens
                        <= self.token_budget)
          seqs_fit = self.num_curr_seqs + num_new_seqs <= self.max_num_seqs
          return tokens_fit and seqs_fit

      def remaining_token_budget(self):
          """Return how many more tokens can be batched."""
          return self.token_budget - self.num_batched_tokens

      def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
          """Count tokens for `req_id`, unless that id is already counted."""
          counted = self._request_ids_num_batched_tokens
          if req_id not in counted:
              counted.add(req_id)
              self._num_batched_tokens += num_batched_tokens

      def subtract_num_batched_tokens(self, req_id: str,
                                      num_batched_tokens: int):
          """Remove tokens for `req_id`, only if that id is currently counted."""
          counted = self._request_ids_num_batched_tokens
          if req_id not in counted:
              return
          counted.remove(req_id)
          self._num_batched_tokens -= num_batched_tokens

      def add_num_seqs(self, req_id: str, num_curr_seqs: int):
          """Count sequences for `req_id`, unless that id is already counted."""
          counted = self._request_ids_num_curr_seqs
          if req_id not in counted:
              counted.add(req_id)
              self._num_curr_seqs += num_curr_seqs

      def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
          """Remove sequences for `req_id`, only if that id is counted."""
          counted = self._request_ids_num_curr_seqs
          if req_id not in counted:
              return
          counted.remove(req_id)
          self._num_curr_seqs -= num_curr_seqs

      @property
      def num_batched_tokens(self):
          """Tokens counted so far."""
          return self._num_batched_tokens

      @property
      def num_curr_seqs(self):
          """Sequences counted so far."""
          return self._num_curr_seqs
  @dataclass
  class ScheduledSequenceGroup:
      """A sequence group paired with the chunk size chosen for it."""

      # The sequence group that was scheduled.
      seq_group: SequenceGroup
      # Number of tokens to process in the next iteration: 1 for decoding,
      # the prompt length for a prefill, or less than the prompt length when
      # the prefill is chunked.
      token_chunk_size: int
  @dataclass
  class SchedulerOutputs:
      """The scheduling decision produced by a scheduler.

      After construction the instance also exposes `num_loras` and
      `num_prompt_adapters`, and `scheduled_seq_groups` is re-sorted by
      (lora_int_id, request_id) whenever at least one LoRA request is present.
      """

      # Sequence groups scheduled for this step.
      scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
      # How many of the scheduled groups are prefills.
      num_prefill_groups: int
      # Total number of batched tokens.
      num_batched_tokens: int
      # Blocks to swap in, as (CPU block number, GPU block number) pairs.
      blocks_to_swap_in: List[Tuple[int, int]]
      # Blocks to swap out, as (GPU block number, CPU block number) pairs.
      blocks_to_swap_out: List[Tuple[int, int]]
      # Blocks to copy, as (source block, destination block) pairs.
      blocks_to_copy: List[Tuple[int, int]]
      # Sequence groups that are going to be ignored.
      ignored_seq_groups: List[SequenceGroup]
      # Number of slots for lookahead decoding.
      num_lookahead_slots: int
      # Number of requests in the running queue.
      running_queue_size: int
      # NOTE(review): presumably the number of groups preempted in this step;
      # confirm against the code that builds this object.
      preempted: int

      def __post_init__(self):
          # Swap-in and swap-out must never be requested in the same step.
          assert not self.blocks_to_swap_in or not self.blocks_to_swap_out

          self.num_loras: int = len(self.lora_requests)
          if self.num_loras > 0:
              self._sort_by_lora_ids()

          self.num_prompt_adapters: int = len(self.prompt_adapter_requests)

      def is_empty(self) -> bool:
          """Return True when nothing is scheduled, swapped or copied.

          NOTE: ignored sequence groups are not taken into account.
          """
          return not (self.scheduled_seq_groups or self.blocks_to_swap_in
                      or self.blocks_to_swap_out or self.blocks_to_copy)

      def _sort_by_lora_ids(self):
          """Order the scheduled groups by (lora_int_id, request_id)."""

          def lora_then_request(scheduled):
              group = scheduled.seq_group
              return (group.lora_int_id, group.request_id)

          self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
                                             key=lora_then_request)

      @property
      def lora_requests(self) -> Set[LoRARequest]:
          """Distinct non-None LoRA requests among the scheduled groups."""
          candidates = (scheduled.seq_group.lora_request
                        for scheduled in self.scheduled_seq_groups)
          return {request for request in candidates if request is not None}

      @property
      def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
          """Distinct non-None prompt adapter requests among the scheduled
          groups."""
          candidates = (scheduled.seq_group.prompt_adapter_request
                        for scheduled in self.scheduled_seq_groups)
          return {request for request in candidates if request is not None}
  @dataclass
  class SchedulerRunningOutputs:
      """The requests that are scheduled from a running queue.

      Could contain prefill (prefill that's chunked) or decodes. If there's not
      enough memory, it can be preempted (for recompute) or swapped out.
      """

      # Selected groups that are running and in a decoding phase. The
      # scheduler stores ScheduledSequenceGroup wrappers here, not bare
      # SequenceGroup objects.
      decode_seq_groups: List[ScheduledSequenceGroup]
      # Selected groups that are running and in a prefill phase, i.e. the
      # prefill has been chunked. Also ScheduledSequenceGroup wrappers.
      prefill_seq_groups: List[ScheduledSequenceGroup]
      # The preempted sequence groups.
      preempted: List[SequenceGroup]
      # Sequence groups that are swapped out.
      swapped_out: List[SequenceGroup]
      # The blocks to swap out.
      blocks_to_swap_out: List[Tuple[int, int]]
      # The blocks to copy.
      blocks_to_copy: List[Tuple[int, int]]
      # The number of slots for lookahead decoding.
      num_lookahead_slots: int

      @classmethod
      def create_empty(cls) -> "SchedulerRunningOutputs":
          """Return an instance with fresh empty lists and zero lookahead
          slots."""
          return cls(
              decode_seq_groups=[],
              prefill_seq_groups=[],
              preempted=[],
              swapped_out=[],
              blocks_to_swap_out=[],
              blocks_to_copy=[],
              num_lookahead_slots=0,
          )
  @dataclass
  class SchedulerSwappedInOutputs:
      """The requests that are scheduled from a swap queue.

      Could contain prefill (prefill that's chunked) or decodes.
      """

      # Selected groups that are going to be swapped in and are in a decoding
      # phase. The scheduler stores ScheduledSequenceGroup wrappers here, not
      # bare SequenceGroup objects.
      decode_seq_groups: List[ScheduledSequenceGroup]
      # Selected groups that are going to be swapped in and are in a prefill
      # phase, i.e. the prefill has been chunked. Also wrappers.
      prefill_seq_groups: List[ScheduledSequenceGroup]
      # The blocks to swap in.
      blocks_to_swap_in: List[Tuple[int, int]]
      # The blocks to copy.
      blocks_to_copy: List[Tuple[int, int]]
      # The number of slots for lookahead decoding.
      num_lookahead_slots: int
      # Infeasible sequence groups.
      infeasible_seq_groups: List[SequenceGroup]

      @classmethod
      def create_empty(cls) -> "SchedulerSwappedInOutputs":
          """Return an instance with fresh empty lists and zero lookahead
          slots."""
          return cls(
              decode_seq_groups=[],
              prefill_seq_groups=[],
              blocks_to_swap_in=[],
              blocks_to_copy=[],
              num_lookahead_slots=0,
              infeasible_seq_groups=[],
          )
  @dataclass
  class SchedulerPrefillOutputs:
      """The requests that are scheduled from a waiting queue.

      Could contain fresh prefill requests or preempted requests that need
      to be recomputed from scratch.
      """

      # Groups selected for prefill. The scheduler stores
      # ScheduledSequenceGroup wrappers here, not bare SequenceGroup objects.
      seq_groups: List[ScheduledSequenceGroup]
      # Ignored sequence groups.
      ignored_seq_groups: List[SequenceGroup]
      # The number of slots for lookahead decoding.
      num_lookahead_slots: int

      @classmethod
      def create_empty(cls) -> "SchedulerPrefillOutputs":
          """Return an instance with fresh empty lists and zero lookahead
          slots."""
          return cls(
              seq_groups=[],
              ignored_seq_groups=[],
              num_lookahead_slots=0,
          )
  217. class Scheduler:
  def __init__(
      self,
      scheduler_config: SchedulerConfig,
      cache_config: CacheConfig,
      lora_config: Optional[LoRAConfig],
      pipeline_parallel_size: int = 1,
  ) -> None:
      """Store the configs, build the block manager and create empty queues.

      Args:
          scheduler_config: Scheduling options; also selects which block
              space manager implementation is used.
          cache_config: Block size, block counts, sliding window and prefix
              caching settings passed on to the block manager.
          lora_config: LoRA settings, or None.
          pipeline_parallel_size: The GPU and CPU block counts are
              floor-divided by this value before the block manager is
              created.
      """
      self.scheduler_config = scheduler_config
      self.cache_config = cache_config
      # Note for LoRA scheduling: the current policy is extremely
      # simple and NOT fair. It can lead to starvation of some
      # LoRAs. This should be improved in the future.
      self.lora_config = lora_config

      # Embedding mode takes precedence over the v2 block manager flag.
      wants_v2 = self.scheduler_config.use_v2_block_manager
      wants_embedding = self.scheduler_config.embedding_mode
      if wants_embedding:
          version = "embedding"
      elif wants_v2:
          version = "v2"
      else:
          version = "v1"
      manager_cls = BlockSpaceManager.get_block_space_manager_class(version)

      # Falsy block counts (e.g. None or 0) are passed through unchanged.
      gpu_blocks = cache_config.num_gpu_blocks
      if gpu_blocks:
          gpu_blocks //= pipeline_parallel_size
      cpu_blocks = cache_config.num_cpu_blocks
      if cpu_blocks:
          cpu_blocks //= pipeline_parallel_size

      # Create the block space manager.
      self.block_manager = manager_cls(
          block_size=self.cache_config.block_size,
          num_gpu_blocks=gpu_blocks,
          num_cpu_blocks=cpu_blocks,
          sliding_window=self.cache_config.sliding_window,
          enable_caching=self.cache_config.enable_prefix_caching)

      # WAITING groups: new prefills or preempted requests.
      self.waiting: Deque[SequenceGroup] = deque()
      # RUNNING groups: decode requests.
      self.running: Deque[SequenceGroup] = deque()
      # SWAPPED groups: decode requests that are swapped out.
      self.swapped: Deque[SequenceGroup] = deque()

      # Request ids finished since the last step iteration. It lets the
      # model know that any state associated with these requests can and
      # must be released after the current step; this is used to evict the
      # finished requests from the Mamba cache.
      self._finished_requests_ids: List[str] = []

      # Time at the previous scheduling step.
      self.prev_time = 0.0
      # Whether a prompt was scheduled at the previous step.
      self.prev_prompt = False
      # Latency of the last prompt step.
      self.last_prompt_latency = 0.0
      # Preemption mode requested by the user: RECOMPUTE or SWAP.
      self.user_specified_preemption_mode = scheduler_config.preemption_mode

      # Test-only: artificial preemption injection.
      self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
      if self.enable_artificial_preemption:
          self.artificial_preempt_cnt = ARTIFICIAL_PREEMPTION_MAX_CNT
      else:
          self.artificial_preempt_cnt = 0
      self.num_cumulative_preemption: int = 0
  280. @property
  281. def lora_enabled(self) -> bool:
  282. return bool(self.lora_config)
  283. @property
  284. def num_decoding_tokens_per_seq(self) -> int:
  285. """The number of new tokens."""
  286. return 1
  287. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  288. # Add sequence groups to the waiting queue.
  289. self.waiting.append(seq_group)
  290. def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
  291. # Add sequence groups to the running queue.
  292. # Only for testing purposes.
  293. self.running.append(seq_group)
  294. def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
  295. # Add sequence groups to the swapped queue.
  296. # Only for testing purposes.
  297. self.swapped.append(seq_group)
  298. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  299. """Aborts a sequence group with the given ID.
  300. Check if the sequence group with the given ID
  301. is present in any of the state queue.
  302. If present, remove the sequence group from the state queue.
  303. Also, if any of the sequences in the sequence group is not finished,
  304. free the sequence with status `FINISHED_ABORTED`.
  305. Otherwise, do nothing.
  306. Args:
  307. request_id: The ID(s) of the sequence group to abort.
  308. """
  309. if isinstance(request_id, str):
  310. request_id = (request_id, )
  311. request_ids = set(request_id)
  312. for state_queue in [self.waiting, self.running, self.swapped]:
  313. aborted_groups: List[SequenceGroup] = []
  314. for seq_group in state_queue:
  315. if not request_ids:
  316. # Using 'break' here may add two extra iterations,
  317. # but is acceptable to reduce complexity.
  318. break
  319. if seq_group.request_id in request_ids:
  320. # Appending aborted group into pending list.
  321. aborted_groups.append(seq_group)
  322. request_ids.remove(seq_group.request_id)
  323. for aborted_group in aborted_groups:
  324. # Remove the sequence group from the state queue.
  325. state_queue.remove(aborted_group)
  326. # Remove the aborted request from the Mamba cache.
  327. self._finished_requests_ids.append(aborted_group.request_id)
  328. for seq in aborted_group.get_seqs():
  329. if seq.is_finished():
  330. continue
  331. seq.status = SequenceStatus.FINISHED_ABORTED
  332. self.free_seq(seq)
  333. def has_unfinished_seqs(self) -> bool:
  334. return len(self.waiting) != 0 or len(self.running) != 0 or len(
  335. self.swapped) != 0
  336. def get_num_unfinished_seq_groups(self) -> int:
  337. return len(self.waiting) + len(self.running) + len(self.swapped)
  338. def get_and_reset_finished_requests_ids(self) -> List[str]:
  339. """Flushes the list of request ids of previously finished seq_groups."""
  340. finished_requests_ids = self._finished_requests_ids
  341. self._finished_requests_ids = list()
  342. return finished_requests_ids
  def _schedule_running(
      self,
      budget: SchedulingBudget,
      curr_loras: Optional[Set[int]],
      enable_chunking: bool = False,
  ) -> SchedulerRunningOutputs:
      """Schedule sequence groups that are running.

      Running queue should include decode and chunked prefill requests.
      Groups are taken from the front of `self.running`; when a group cannot
      be given slots, groups are preempted from the back of the queue (or
      the group itself, if it is the only one left).

      Args:
          budget: The scheduling budget. The argument is in-place updated
              when any decodes are preempted.
          curr_loras: Currently batched lora request ids. The argument is
              in-place updated when any decodes are preempted.
          enable_chunking: If True, seq group can be chunked and only a
              chunked number of tokens are scheduled if
              `budget.num_batched_tokens` has not enough capacity to schedule
              all tokens.

      Returns:
          SchedulerRunningOutputs.
      """
      # Blocks that need to be swapped or copied before model execution.
      blocks_to_swap_out: List[Tuple[int, int]] = []
      blocks_to_copy: List[Tuple[int, int]] = []

      decodes: List[ScheduledSequenceGroup] = []
      prefills: List[ScheduledSequenceGroup] = []
      recompute_victims: List[SequenceGroup] = []
      swap_victims: List[SequenceGroup] = []

      def evict(group: SequenceGroup) -> None:
          # Preempt `group` and file it under the mode that was applied.
          mode = self._preempt(group, blocks_to_swap_out)
          if mode == PreemptionMode.RECOMPUTE:
              recompute_victims.append(group)
          else:
              swap_victims.append(group)

      # NOTE: Preemption happens only when there is no available slot
      # to keep all the sequence groups in the RUNNING state.
      running_queue = self.running
      while running_queue:
          seq_group = running_queue[0]
          num_running_tokens = self._get_num_new_tokens(
              seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
          if num_running_tokens == 0:
              break

          running_queue.popleft()
          while not self._can_append_slots(seq_group):
              # Take this group's contribution back out of the budget and
              # the active LoRA set while space is being made for it.
              budget.subtract_num_batched_tokens(seq_group.request_id,
                                                 num_running_tokens)
              budget.subtract_num_seqs(seq_group.request_id,
                                       seq_group.get_max_num_running_seqs())
              lora_id = seq_group.lora_int_id
              if (curr_loras is not None and lora_id > 0
                      and lora_id in curr_loras):
                  curr_loras.remove(lora_id)

              if not running_queue:
                  # No other sequence groups can be preempted.
                  # Preempt the current sequence group.
                  evict(seq_group)
                  break
              # Preempt the lowest-priority sequence group.
              evict(running_queue.pop())
          else:
              # Reached only when the loop above ended without `break`,
              # i.e. slots can be appended for this group.
              self._append_slots(seq_group, blocks_to_copy)
              if seq_group.is_prefill():
                  bucket, chunk_size = prefills, num_running_tokens
              else:
                  bucket, chunk_size = decodes, 1
              bucket.append(
                  ScheduledSequenceGroup(seq_group=seq_group,
                                         token_chunk_size=chunk_size))
              budget.add_num_batched_tokens(seq_group.request_id,
                                            num_running_tokens)
              # OPTIMIZATION: get_max_num_running_seqs is expensive. In the
              # default scheduling case (enable_chunking is False) num_seqs
              # is updated before this method runs, so it is not updated
              # again here.
              if enable_chunking:
                  budget.add_num_seqs(seq_group.request_id,
                                      seq_group.get_max_num_running_seqs())
              if curr_loras is not None and seq_group.lora_int_id > 0:
                  curr_loras.add(seq_group.lora_int_id)

      return SchedulerRunningOutputs(
          decode_seq_groups=decodes,
          prefill_seq_groups=prefills,
          preempted=recompute_victims,
          swapped_out=swap_victims,
          blocks_to_swap_out=blocks_to_swap_out,
          blocks_to_copy=blocks_to_copy,
          num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=False))
  def _schedule_swapped(
      self,
      budget: SchedulingBudget,
      curr_loras: Optional[Set[int]],
      enable_chunking: bool = False,
  ) -> SchedulerSwappedInOutputs:
      """Schedule sequence groups that are swapped out.

      It schedules swapped requests as long as it fits `budget` and
      curr_loras <= max_lora from the scheduling config. The input arguments
      `budget` and `curr_loras` are updated based on scheduled seq_groups.

      Args:
          budget: The scheduling budget. The argument is in-place updated
              when any requests are swapped in.
          curr_loras: Currently batched lora request ids. The argument is
              in-place updated when any requests are swapped in.
          enable_chunking: If True, seq group can be chunked and only a
              chunked number of tokens are scheduled if
              `budget.num_batched_tokens` has not enough capacity to schedule
              all tokens.

      Returns:
          SchedulerSwappedInOutputs.
      """
      # Blocks that need to be swapped or copied before model execution.
      blocks_to_swap_in: List[Tuple[int, int]] = []
      blocks_to_copy: List[Tuple[int, int]] = []

      decodes: List[ScheduledSequenceGroup] = []
      prefills: List[ScheduledSequenceGroup] = []
      infeasible: List[SequenceGroup] = []
      # Groups skipped because no LoRA slot was free; they are put back at
      # the front of the swapped queue once the loop is done.
      deferred: Deque[SequenceGroup] = deque()

      swapped_queue = self.swapped
      while swapped_queue:
          seq_group = swapped_queue[0]

          # If the sequence group cannot be swapped in, stop.
          lookahead = self._get_num_lookahead_slots(seq_group.is_prefill())
          alloc_status = self.block_manager.can_swap_in(seq_group, lookahead)
          if alloc_status == AllocStatus.LATER:
              break
          if alloc_status == AllocStatus.NEVER:
              logger.warning(f"Failing the request {seq_group.request_id} "
                             "because there's not enough kv cache blocks to "
                             "run the entire sequence.")
              for seq in seq_group.get_seqs():
                  seq.status = SequenceStatus.FINISHED_IGNORED
              infeasible.append(seq_group)
              swapped_queue.popleft()
              continue

          lora_int_id = 0
          if self.lora_enabled:
              lora_int_id = seq_group.lora_int_id
              assert curr_loras is not None
              assert self.lora_config is not None
              needs_new_slot = (lora_int_id > 0
                                and lora_int_id not in curr_loras)
              if (needs_new_slot
                      and len(curr_loras) >= self.lora_config.max_loras):
                  # We don't have a space for another LoRA, so
                  # we ignore this request for now.
                  deferred.appendleft(seq_group)
                  swapped_queue.popleft()
                  continue

          # The total number of sequences in the RUNNING state should not
          # exceed the maximum number of sequences.
          num_new_seqs = seq_group.get_max_num_running_seqs()
          num_new_tokens = self._get_num_new_tokens(seq_group,
                                                    SequenceStatus.SWAPPED,
                                                    enable_chunking, budget)
          if num_new_tokens == 0:
              break
          if not budget.can_schedule(num_new_tokens=num_new_tokens,
                                     num_new_seqs=num_new_seqs):
              break

          if lora_int_id > 0 and curr_loras is not None:
              curr_loras.add(lora_int_id)
          swapped_queue.popleft()
          self._swap_in(seq_group, blocks_to_swap_in)
          self._append_slots(seq_group, blocks_to_copy)

          # The prefill state is queried again after the swap-in.
          if seq_group.is_prefill():
              prefills.append(
                  ScheduledSequenceGroup(seq_group,
                                         token_chunk_size=num_new_tokens))
          else:
              decodes.append(
                  ScheduledSequenceGroup(seq_group, token_chunk_size=1))
          budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
          budget.add_num_seqs(seq_group.request_id, num_new_seqs)

      swapped_queue.extendleft(deferred)

      return SchedulerSwappedInOutputs(
          decode_seq_groups=decodes,
          prefill_seq_groups=prefills,
          blocks_to_swap_in=blocks_to_swap_in,
          blocks_to_copy=blocks_to_copy,
          num_lookahead_slots=self._get_num_lookahead_slots(
              is_prefill=False),
          infeasible_seq_groups=infeasible,
      )
  534. def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
  535. if self.scheduler_config.chunked_prefill_enabled:
  536. prompt_limit = self.scheduler_config.max_model_len
  537. else:
  538. prompt_limit = min(self.scheduler_config.max_model_len,
  539. self.scheduler_config.max_num_batched_tokens)
  540. # Model is fine tuned with long context. Return the fine tuned max_len.
  541. if (seq_group.lora_request
  542. and seq_group.lora_request.long_lora_max_len):
  543. assert prompt_limit <= seq_group.lora_request.long_lora_max_len
  544. return seq_group.lora_request.long_lora_max_len
  545. else:
  546. return prompt_limit
  def _schedule_prefills(
      self,
      budget: SchedulingBudget,
      curr_loras: Optional[Set[int]],
      enable_chunking: bool = False,
  ) -> SchedulerPrefillOutputs:
      """Schedule sequence groups that are in prefill stage.

      Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
      as a new prefill (that starts from beginning -> most recently generated
      tokens).

      It schedules waiting requests as long as it fits `budget` and
      curr_loras <= max_lora from the scheduling config. The input arguments
      `budget` and `curr_loras` are updated based on scheduled seq_groups.

      Args:
          budget: The scheduling budget. The argument is in-place updated
              when any requests are scheduled.
          curr_loras: Currently batched lora request ids. The argument is
              in-place updated when any requests are scheduled.
          enable_chunking: If True, seq group can be chunked and only a
              chunked number of tokens are scheduled if
              `budget.num_batched_tokens` has not enough capacity to schedule
              all tokens.

      Returns:
          SchedulerPrefillOutputs.
      """
      ignored_seq_groups: List[SequenceGroup] = []
      seq_groups: List[ScheduledSequenceGroup] = []

      waiting_queue = self.waiting
      # Groups skipped because no LoRA slot was free; they are put back at
      # the front of the waiting queue once the loop is done.
      leftover_waiting_sequences: Deque[SequenceGroup] = deque()
      while self._passed_delay(time.time()) and waiting_queue:
          seq_group = waiting_queue[0]

          waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
          assert len(waiting_seqs) == 1, (
              "Waiting sequence group should have only one prompt "
              "sequence.")
          num_new_tokens = self._get_num_new_tokens(seq_group,
                                                    SequenceStatus.WAITING,
                                                    enable_chunking, budget)
          if not enable_chunking:
              num_prompt_tokens = waiting_seqs[0].get_len()
              assert num_new_tokens == num_prompt_tokens

          prompt_limit = self._get_prompt_limit(seq_group)
          if num_new_tokens > prompt_limit:
              # `logger` is loguru, which formats messages with str.format;
              # printf-style "%d" placeholders would be logged verbatim, so
              # the values are interpolated with an f-string instead.
              logger.warning(
                  f"Input prompt ({num_new_tokens} tokens) is too long"
                  f" and exceeds limit of {prompt_limit}")
              for seq in waiting_seqs:
                  seq.status = SequenceStatus.FINISHED_IGNORED
              ignored_seq_groups.append(seq_group)
              waiting_queue.popleft()
              continue

          # If the sequence group cannot be allocated, stop.
          can_allocate = self.block_manager.can_allocate(seq_group)
          if can_allocate == AllocStatus.LATER:
              break
          elif can_allocate == AllocStatus.NEVER:
              logger.warning(
                  f"Input prompt ({num_new_tokens} tokens) is too long"
                  " and exceeds the capacity of block_manager")
              for seq in waiting_seqs:
                  seq.status = SequenceStatus.FINISHED_IGNORED
              ignored_seq_groups.append(seq_group)
              waiting_queue.popleft()
              continue

          lora_int_id = 0
          if self.lora_enabled:
              lora_int_id = seq_group.lora_int_id
              assert curr_loras is not None
              assert self.lora_config is not None
              if (lora_int_id > 0 and lora_int_id not in curr_loras
                      and len(curr_loras) >= self.lora_config.max_loras):
                  # We don't have a space for another LoRA, so
                  # we ignore this request for now.
                  leftover_waiting_sequences.appendleft(seq_group)
                  waiting_queue.popleft()
                  continue

          num_new_seqs = seq_group.get_max_num_running_seqs()
          if (num_new_tokens == 0
                  or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                             num_new_seqs=num_new_seqs)):
              break

          # Can schedule this request.
          if curr_loras is not None and lora_int_id > 0:
              curr_loras.add(lora_int_id)
          waiting_queue.popleft()
          self._allocate_and_set_running(seq_group)
          seq_groups.append(
              ScheduledSequenceGroup(seq_group=seq_group,
                                     token_chunk_size=num_new_tokens))
          budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
          budget.add_num_seqs(seq_group.request_id, num_new_seqs)

      # Queue requests that couldn't be scheduled.
      waiting_queue.extendleft(leftover_waiting_sequences)
      if seq_groups:
          # Remember that this step scheduled at least one prompt.
          self.prev_prompt = True

      return SchedulerPrefillOutputs(
          seq_groups=seq_groups,
          ignored_seq_groups=ignored_seq_groups,
          num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
  def _schedule_default(self) -> SchedulerOutputs:
      """Run one scheduling step without chunked prefill.

      This policy favours throughput: when nothing is swapped out it first
      tries to batch prefill requests; decodes are scheduled only if no
      prefill was scheduled, and swapped groups are brought back only if
      scheduling the running queue preempted or swapped out nothing.

      Returns:
          The ``SchedulerOutputs`` describing the batch, the block
          swap/copy operations, ignored groups and preemption count.
      """
      cfg = self.scheduler_config
      budget = SchedulingBudget(
          token_budget=cfg.max_num_batched_tokens,
          max_num_seqs=cfg.max_num_seqs,
      )
      # Account for the already-running sequences up front so that prefill
      # scheduling cannot push the batch past max_num_seqs.
      for running_group in self.running:
          budget.add_num_seqs(running_group.request_id,
                              running_group.get_max_num_running_seqs())

      curr_loras: Optional[Set[int]]
      if self.lora_enabled:
          curr_loras = {
              group.lora_int_id
              for group in self.running if group.lora_int_id > 0
          }
      else:
          curr_loras = None

      prefills = SchedulerPrefillOutputs.create_empty()
      running_scheduled = SchedulerRunningOutputs.create_empty()
      swapped_in = SchedulerSwappedInOutputs.create_empty()

      # Swapped requests take priority: new prefills are only considered
      # when the swapped queue is empty.
      if not self.swapped:
          prefills = self._schedule_prefills(budget,
                                             curr_loras,
                                             enable_chunking=False)

      # Decodes are skipped whenever a prefill was scheduled.
      # NOTE: With chunking disabled in `_schedule_prefills`, self.running
      # holds decode requests only, never chunked prefills.
      if not prefills.seq_groups:
          running_scheduled = self._schedule_running(budget,
                                                     curr_loras,
                                                     enable_chunking=False)
          # A preemption or swap-out means there is no room for more
          # running requests, so do not swap anything in.
          nothing_evicted = (not running_scheduled.preempted
                             and not running_scheduled.swapped_out)
          if nothing_evicted:
              swapped_in = self._schedule_swapped(budget, curr_loras)

      assert budget.num_batched_tokens <= cfg.max_num_batched_tokens
      assert budget.num_curr_seqs <= cfg.max_num_seqs

      # Preempted groups go back to the front of the waiting queue.
      self.waiting.extendleft(running_scheduled.preempted)
      # Everything scheduled in this step becomes (or stays) running.
      for scheduled in (prefills.seq_groups,
                        running_scheduled.decode_seq_groups,
                        swapped_in.decode_seq_groups):
          self.running.extend([s.seq_group for s in scheduled])
      # Swapped-out groups join the swapped queue.
      self.swapped.extend(running_scheduled.swapped_out)
      preempted = (len(running_scheduled.preempted) +
                   len(running_scheduled.swapped_out))

      # This policy never chunks prefills, so neither the running queue
      # nor the swapped queue may have produced a prefill.
      assert len(running_scheduled.prefill_seq_groups) == 0
      assert len(swapped_in.prefill_seq_groups) == 0

      scheduled_seq_groups = (prefills.seq_groups +
                              running_scheduled.decode_seq_groups +
                              swapped_in.decode_seq_groups)
      return SchedulerOutputs(
          scheduled_seq_groups=scheduled_seq_groups,
          num_prefill_groups=len(prefills.seq_groups),
          num_batched_tokens=budget.num_batched_tokens,
          blocks_to_swap_in=swapped_in.blocks_to_swap_in,
          blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
          blocks_to_copy=running_scheduled.blocks_to_copy +
          swapped_in.blocks_to_copy,
          ignored_seq_groups=prefills.ignored_seq_groups +
          swapped_in.infeasible_seq_groups,
          num_lookahead_slots=running_scheduled.num_lookahead_slots,
          running_queue_size=len(self.running),
          preempted=preempted,
      )
  def _schedule_chunked_prefill(self):
      """Run one scheduling step with chunked prefill enabled.

      Prefill requests may be split into chunks and batched together with
      decode requests. Scheduling order within a step:
        1. running requests (decodes and unfinished chunked prefills),
        2. swapped-out requests, only if step 1 evicted nothing,
        3. new prefill requests.

      Mixing prefills and decodes in one batch keeps GPU utilization high
      and improves inter-token latency, since decodes are not blocked
      behind prefills.

      Returns:
          The ``SchedulerOutputs`` for this step.
      """
      cfg = self.scheduler_config
      budget = SchedulingBudget(
          token_budget=cfg.max_num_batched_tokens,
          max_num_seqs=cfg.max_num_seqs,
      )
      curr_loras: Set[int] = set()

      prefills = SchedulerPrefillOutputs.create_empty()
      swapped_in = SchedulerSwappedInOutputs.create_empty()

      # Running requests are always scheduled first (FCFS).
      running_scheduled = self._schedule_running(budget,
                                                 curr_loras,
                                                 enable_chunking=True)

      # Any preemption/swap-out signals that there is no room to swap in.
      nothing_evicted = (not running_scheduled.preempted
                         and not running_scheduled.swapped_out)
      if nothing_evicted:
          swapped_in = self._schedule_swapped(budget, curr_loras)

      # Whatever budget remains goes to new prefills.
      prefills = self._schedule_prefills(budget,
                                         curr_loras,
                                         enable_chunking=True)

      assert budget.num_batched_tokens <= cfg.max_num_batched_tokens
      assert budget.num_curr_seqs <= cfg.max_num_seqs

      # Preempted groups go back to the front of the waiting queue.
      self.waiting.extendleft(running_scheduled.preempted)
      # Everything scheduled in this step becomes (or stays) running.
      for scheduled in (prefills.seq_groups,
                        running_scheduled.decode_seq_groups,
                        running_scheduled.prefill_seq_groups,
                        swapped_in.decode_seq_groups,
                        swapped_in.prefill_seq_groups):
          self.running.extend([s.seq_group for s in scheduled])
      # Swapped-out groups join the swapped queue.
      self.swapped.extend(running_scheduled.swapped_out)

      # Prefill groups are listed ahead of decode groups in the batch.
      scheduled_seq_groups = (prefills.seq_groups +
                              running_scheduled.prefill_seq_groups +
                              swapped_in.prefill_seq_groups +
                              running_scheduled.decode_seq_groups +
                              swapped_in.decode_seq_groups)
      num_prefill_groups = (len(prefills.seq_groups) +
                            len(swapped_in.prefill_seq_groups) +
                            len(running_scheduled.prefill_seq_groups))
      return SchedulerOutputs(
          scheduled_seq_groups=scheduled_seq_groups,
          num_prefill_groups=num_prefill_groups,
          num_batched_tokens=budget.num_batched_tokens,
          blocks_to_swap_in=swapped_in.blocks_to_swap_in,
          blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
          blocks_to_copy=running_scheduled.blocks_to_copy +
          swapped_in.blocks_to_copy,
          ignored_seq_groups=prefills.ignored_seq_groups +
          swapped_in.infeasible_seq_groups,
          num_lookahead_slots=running_scheduled.num_lookahead_slots,
          running_queue_size=len(self.running),
          preempted=(len(running_scheduled.preempted) +
                     len(running_scheduled.swapped_out)),
      )
  793. def _schedule(self) -> SchedulerOutputs:
  794. """Schedule queued requests."""
  795. if self.scheduler_config.chunked_prefill_enabled:
  796. return self._schedule_chunked_prefill()
  797. else:
  798. return self._schedule_default()
  799. def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
  800. """Determine whether or not we have enough space in the KV cache to
  801. continue generation of the sequence group.
  802. """
  803. # It is True only for testing case to trigger artificial preemption.
  804. if (self.enable_artificial_preemption
  805. and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
  806. and self.artificial_preempt_cnt > 0):
  807. self.artificial_preempt_cnt -= 1
  808. return False
  809. # Appending slots only occurs in decoding.
  810. is_prefill = False
  811. return self.block_manager.can_append_slots(
  812. seq_group=seq_group,
  813. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  814. )
  def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
      """Run one scheduling pass and build metadata for the chosen batch.

      Calling this mutates scheduler state (``self.running``,
      ``self.swapped`` and ``self.waiting``) through ``self._schedule()``.

      Returns:
          A tuple ``(seq_group_metadata_list, scheduler_outputs)`` with one
          ``SequenceGroupMetadata`` per scheduled sequence group, in the
          same order as ``scheduler_outputs.scheduled_seq_groups``.
      """
      scheduler_outputs = self._schedule()
      now = time.time()
      block_manager = self.block_manager

      seq_group_metadata_list: List[SequenceGroupMetadata] = []
      for scheduled in scheduler_outputs.scheduled_seq_groups:
          seq_group = scheduled.seq_group
          token_chunk_size = scheduled.token_chunk_size
          seq_group.maybe_set_first_scheduled_time(now)

          # Both maps are keyed by seq_id: the sequence data and the
          # physical block numbers of each running sequence.
          seq_data: Dict[int, SequenceData] = {}
          block_tables: Dict[int, List[int]] = {}
          for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
              seq_id = seq.seq_id
              seq_data[seq_id] = seq.data
              block_tables[seq_id] = block_manager.get_block_table(seq)
              block_manager.access_all_blocks_in_seq(seq, now)

          running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
          common_computed_block_nums = (
              block_manager.get_common_computed_block_ids(running_seqs))

          do_sample = True
          if seq_group.is_prefill():
              seqs = seq_group.get_seqs()
              # Prefill has only 1 sequence.
              assert len(seqs) == 1
              # If some tokens are still uncomputed after this step, the
              # prefill is chunked and no sampling is needed yet.
              # NOTE: get_len is used rather than get_prompt_len because
              # after a preemption the prefill also covers previously
              # generated output tokens.
              computed_after_step = (
                  token_chunk_size +
                  seqs[0].data.get_num_computed_tokens())
              if computed_after_step < seqs[0].data.get_len():
                  do_sample = False

          # It assumes the scheduled_seq_groups is ordered by
          # prefill < decoding.
          is_prompt = seq_group.is_prefill()
          seq_group_metadata_list.append(
              SequenceGroupMetadata(
                  request_id=seq_group.request_id,
                  is_prompt=is_prompt,
                  seq_data=seq_data,
                  sampling_params=seq_group.sampling_params,
                  block_tables=block_tables,
                  do_sample=do_sample,
                  pooling_params=seq_group.pooling_params,
                  token_chunk_size=token_chunk_size,
                  lora_request=seq_group.lora_request,
                  computed_block_nums=common_computed_block_nums,
                  # `multi_modal_data` is only sent with the first
                  # engine -> worker communication; later ones can still
                  # use delta, but carry None here.
                  multi_modal_data=seq_group.multi_modal_data
                  if scheduler_outputs.num_prefill_groups > 0 else None,
                  prompt_adapter_request=seq_group.prompt_adapter_request,
              ))

      # The batch is now fixed, so all of its blocks can be treated as
      # computed before the next scheduling call. This relies on the engine
      # assuming that a failure in model execution will crash the Aphrodite
      # instance / will not retry.
      for scheduled in scheduler_outputs.scheduled_seq_groups:
          block_manager.mark_blocks_as_computed(scheduled.seq_group)

      return seq_group_metadata_list, scheduler_outputs
  def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
      """Ask the block manager to fork ``parent_seq`` into ``child_seq``."""
      block_manager = self.block_manager
      block_manager.fork(parent_seq, child_seq)
  def free_seq(self, seq: Sequence) -> None:
      """Release the given sequence through the block manager."""
      block_manager = self.block_manager
      block_manager.free(seq)
  def free_finished_seq_groups(self) -> None:
      """Drop finished groups from ``self.running`` and record their ids.

      The request id of every finished group is appended to
      ``self._finished_requests_ids``; that list is used to update the
      Mamba cache in the next step. ``self.running`` is replaced by a new
      deque holding only the unfinished groups, in their original order.
      """
      still_running: Deque[SequenceGroup] = deque()
      for group in self.running:
          if not group.is_finished():
              still_running.append(group)
              continue
          self._finished_requests_ids.append(group.request_id)
      self.running = still_running
  def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
      """Allocate blocks for the group and flip its WAITING seqs to RUNNING.

      Args:
          seq_group: The group handed to ``block_manager.allocate``.
      """
      self.block_manager.allocate(seq_group)
      waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
      for waiting_seq in waiting_seqs:
          waiting_seq.status = SequenceStatus.RUNNING
  def _append_slots(
      self,
      seq_group: SequenceGroup,
      blocks_to_copy: List[Tuple[int, int]],
  ) -> None:
      """Append new slots to every RUNNING sequence of a sequence group.

      Args:
          seq_group (SequenceGroup): The group whose running sequences get
              new slots.
          blocks_to_copy (List[Tuple[int, int]]): Output list of
              (source block index, destination block index) pairs. It is
              extended in place with the pairs returned by
              ``block_manager.append_slots`` for each sequence.
      """
      lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
      running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
      for running_seq in running_seqs:
          blocks_to_copy.extend(
              self.block_manager.append_slots(running_seq, lookahead_slots))
  def _preempt(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_out: List[Tuple[int, int]],
      preemption_mode: Optional[PreemptionMode] = None,
  ) -> PreemptionMode:
      """Preempt a sequence group by recomputation or by swapping.

      Args:
          seq_group: The group to preempt.
          blocks_to_swap_out: Output list extended with the swap-out block
              mapping when the SWAP mode is used.
          preemption_mode: NOTE(review): as written, this argument is
              always overwritten below, so a caller-supplied value has no
              effect -- confirm whether that is intended.

      Returns:
          The preemption mode that was applied.
      """
      # Mode selection when the user did not configure one:
      # recomputation is the default since it incurs lower overhead than
      # swapping. However, recomputation is not currently supported for
      # groups with multiple sequences (e.g., beam search), so those are
      # swapped instead.
      # FIXME: This makes our scheduling policy a bit bizarre.
      # As swapped sequences are prioritized over waiting sequences,
      # sequence groups with multiple sequences are implicitly prioritized
      # over sequence groups with a single sequence.
      # TODO: Support recomputation for sequence groups with multiple
      # sequences. This may require a more sophisticated CUDA kernel.
      user_mode = self.user_specified_preemption_mode
      if user_mode is None:
          is_single_seq = seq_group.get_max_num_running_seqs() == 1
          preemption_mode = (PreemptionMode.RECOMPUTE
                             if is_single_seq else PreemptionMode.SWAP)
      elif user_mode == "swap":
          preemption_mode = PreemptionMode.SWAP
      else:
          preemption_mode = PreemptionMode.RECOMPUTE

      # Warn on the first preemption and then on every 50th one.
      if self.num_cumulative_preemption % 50 == 0:
          logger.warning(
              f"Sequence group {seq_group.request_id} is preempted by "
              f"{preemption_mode} mode because there is "
              "not enough KV cache space. This can affect the end-to-end "
              "performance. Increase gpu_memory_utilization or "
              "tensor_parallel_size to provide more KV cache memory. "
              "total_num_cumulative_preemption="
              f"{self.num_cumulative_preemption + 1}")
      self.num_cumulative_preemption += 1

      if preemption_mode == PreemptionMode.RECOMPUTE:
          self._preempt_by_recompute(seq_group)
          return preemption_mode
      if preemption_mode == PreemptionMode.SWAP:
          self._preempt_by_swap(seq_group, blocks_to_swap_out)
          return preemption_mode
      raise AssertionError("Invalid preemption mode.")
  def _preempt_by_recompute(
      self,
      seq_group: SequenceGroup,
  ) -> None:
      """Preempt a group by freeing its blocks for later recomputation.

      The group is asserted to have exactly one RUNNING sequence. That
      sequence is set to WAITING, freed via ``self.free_seq`` and has its
      state reset for recompute.
      """
      running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
      assert len(running_seqs) == 1
      for running_seq in running_seqs:
          running_seq.status = SequenceStatus.WAITING
          self.free_seq(running_seq)
          running_seq.reset_state_for_recompute()
  976. def _preempt_by_swap(
  977. self,
  978. seq_group: SequenceGroup,
  979. blocks_to_swap_out: List[Tuple[int, int]],
  980. ) -> None:
  981. self._swap_out(seq_group, blocks_to_swap_out)
  def _swap_in(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_in: List[Tuple[int, int]],
  ) -> None:
      """Swap a group back in and mark its SWAPPED sequences RUNNING.

      Args:
          seq_group: The group to swap in.
          blocks_to_swap_in: Output list extended in place with the mapping
              returned by ``block_manager.swap_in``.
      """
      blocks_to_swap_in.extend(self.block_manager.swap_in(seq_group))
      swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED)
      for swapped_seq in swapped_seqs:
          swapped_seq.status = SequenceStatus.RUNNING
  991. def _swap_out(
  992. self,
  993. seq_group: SequenceGroup,
  994. blocks_to_swap_out: List[Tuple[int, int]],
  995. ) -> None:
  996. if not self.block_manager.can_swap_out(seq_group):
  997. # FIXME: Abort the sequence group instead of aborting the
  998. # entire engine.
  999. raise RuntimeError(
  1000. "Aborted due to the lack of CPU swap space. Please increase "
  1001. "the swap space to avoid this error.")
  1002. mapping = self.block_manager.swap_out(seq_group)
  1003. blocks_to_swap_out.extend(mapping)
  1004. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  1005. seq.status = SequenceStatus.SWAPPED
  1006. def _passed_delay(self, now: float) -> bool:
  1007. if self.prev_prompt:
  1008. self.last_prompt_latency = now - self.prev_time
  1009. self.prev_time, self.prev_prompt = now, False
  1010. # Delay scheduling prompts to let waiting queue fill up
  1011. if self.scheduler_config.delay_factor > 0 and self.waiting:
  1012. earliest_arrival_time = min(
  1013. [e.metrics.arrival_time for e in self.waiting])
  1014. passed_delay = (
  1015. (now - earliest_arrival_time) >
  1016. (self.scheduler_config.delay_factor * self.last_prompt_latency)
  1017. or not self.running)
  1018. else:
  1019. passed_delay = True
  1020. return passed_delay
  1021. def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
  1022. """The number of slots to allocate per sequence per step, beyond known
  1023. token ids. Speculative decoding uses these slots to store KV activations
  1024. of tokens which may or may not be accepted.
  1025. Speculative decoding does not yet support prefill, so we do not perform
  1026. lookahead allocation for prefill.
  1027. """
  1028. if is_prefill:
  1029. return 0
  1030. return self.scheduler_config.num_lookahead_slots
  1031. def _get_num_new_tokens(self, seq_group: SequenceGroup,
  1032. status: SequenceStatus, enable_chunking: bool,
  1033. budget: SchedulingBudget) -> int:
  1034. """Get the next new tokens to compute for a given sequence group
  1035. that's in a given `status`.
  1036. The API could chunk the number of tokens to compute based on `budget`
  1037. if `enable_chunking` is True. If a sequence group has multiple
  1038. sequences (e.g., running beam search), it means it is in decoding
  1039. phase, so chunking doesn't happen.
  1040. Returns 0 if the new token cannot be computed due to token budget.
  1041. """
  1042. num_new_tokens = 0
  1043. seqs = seq_group.get_seqs(status=status)
  1044. for seq in seqs:
  1045. num_new_tokens += seq.get_num_new_tokens()
  1046. assert num_new_tokens > 0
  1047. # Chunk if a running request cannot fit in.
  1048. # If number of seq > 1, it means it is doing beam search in a
  1049. # decode phase. Do not chunk in that case.
  1050. if enable_chunking and len(seqs) == 1:
  1051. num_new_tokens = min(num_new_tokens,
  1052. budget.remaining_token_budget())
  1053. return num_new_tokens