scheduler.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157
  1. import enum
  2. import os
  3. import random
  4. import time
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
  8. from loguru import logger
  9. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  10. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  11. SequenceGroupMetadata, SequenceStatus)
  12. from aphrodite.lora.request import LoRARequest
  13. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  14. from aphrodite.processing.policy import Policy, PolicyFactory
  15. # Test-only. If configured, decode is preempted with
  16. # ARTIFICIAL_PREEMPTION_PROB% probability.
  17. ENABLE_ARTIFICIAL_PREEMPT = bool(
  18. os.getenv("APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
  19. ARTIFICIAL_PREEMPTION_PROB = 0.5
  20. ARTIFICIAL_PREEMPTION_MAX_CNT = 500
  21. class PreemptionMode(enum.Enum):
  22. """Preemption modes.
  23. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  24. and swap them back in when the sequences are resumed.
  25. 2. Recomputation: Discard the blocks of the preempted sequences and
  26. recompute them when the sequences are resumed, treating the sequences as
  27. new prompts.
  28. """
  29. SWAP = enum.auto()
  30. RECOMPUTE = enum.auto()
  31. @dataclass
  32. class SchedulingBudget:
  33. """The available slots for scheduling.
  34. TODO: Right now, the budget is request_id-aware meaning it can ignore
  35. budget update from the same request_id. It is because in normal scheduling
  36. path, we update RUNNING num_seqs ahead of time, meaning it could be
  37. updated more than once when scheduling RUNNING requests. Since this won't
  38. happen if we only have chunked prefill scheduling, we can remove this
  39. feature from the API when chunked prefill is enabled by default.
  40. """
  41. token_budget: int
  42. max_num_seqs: int
  43. _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
  44. _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
  45. _num_batched_tokens: int = 0
  46. _num_curr_seqs: int = 0
  47. def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
  48. assert num_new_tokens != 0
  49. assert num_new_seqs != 0
  50. return (self.num_batched_tokens + num_new_tokens <= self.token_budget
  51. and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
  52. def remaining_token_budget(self):
  53. return self.token_budget - self.num_batched_tokens
  54. def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
  55. if req_id in self._requeset_ids_num_batched_tokens:
  56. return
  57. self._requeset_ids_num_batched_tokens.add(req_id)
  58. self._num_batched_tokens += num_batched_tokens
  59. def subtract_num_batched_tokens(self, req_id: str,
  60. num_batched_tokens: int):
  61. if req_id in self._requeset_ids_num_batched_tokens:
  62. self._requeset_ids_num_batched_tokens.remove(req_id)
  63. self._num_batched_tokens -= num_batched_tokens
  64. def add_num_seqs(self, req_id: str, num_curr_seqs: int):
  65. if req_id in self._requeset_ids_num_curr_seqs:
  66. return
  67. self._requeset_ids_num_curr_seqs.add(req_id)
  68. self._num_curr_seqs += num_curr_seqs
  69. def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
  70. if req_id in self._requeset_ids_num_curr_seqs:
  71. self._requeset_ids_num_curr_seqs.remove(req_id)
  72. self._num_curr_seqs -= num_curr_seqs
  73. @property
  74. def num_batched_tokens(self):
  75. return self._num_batched_tokens
  76. @property
  77. def num_curr_seqs(self):
  78. return self._num_curr_seqs
  79. @dataclass
  80. class ScheduledSequenceGroup:
  81. # A sequence group that's scheduled.
  82. seq_group: SequenceGroup
  83. # The total chunk size (number of tokens) to process for next iteration.
  84. # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
  85. # chunked, it can be smaller than that.
  86. token_chunk_size: int
  87. @dataclass
  88. class SchedulerOutputs:
  89. """The scheduling decision made from a scheduler."""
  90. # Scheduled sequence groups.
  91. scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
  92. # Number of prefill groups scheduled.
  93. num_prefill_groups: int
  94. # Total number of batched tokens.
  95. num_batched_tokens: int
  96. # Blocks to swap in. Dict of CPU -> GPU block number.
  97. blocks_to_swap_in: Dict[int, int]
  98. # Blocks to swap out. Dict of GPU -> CPU block number.
  99. blocks_to_swap_out: Dict[int, int]
  100. # Blocks to copy. Source to dest block.
  101. blocks_to_copy: List[Tuple[int, int]]
  102. # Sequence groups that are going to be ignored.
  103. ignored_seq_groups: List[SequenceGroup]
  104. # The number of slots for lookahead decoding.
  105. num_lookahead_slots: int
  106. # The number of requests in the running queue
  107. running_queue_size: int
  108. def __post_init__(self):
  109. # Swap in and swap out should never happen at the same time.
  110. assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
  111. self.num_loras: int = len(self.lora_requests)
  112. if self.num_loras > 0:
  113. self._sort_by_lora_ids()
  114. def is_empty(self) -> bool:
  115. # NOTE: We do not consider the ignored sequence groups.
  116. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  117. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  118. def _sort_by_lora_ids(self):
  119. self.scheduled_seq_groups = sorted(
  120. self.scheduled_seq_groups,
  121. key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
  122. @property
  123. def lora_requests(self) -> Set[LoRARequest]:
  124. return {
  125. g.seq_group.lora_request
  126. for g in self.scheduled_seq_groups
  127. if g.seq_group.lora_request is not None
  128. }
  129. @dataclass
  130. class SchedulerRunningOutputs:
  131. """The requests that are scheduled from a running queue.
  132. Could contain prefill (prefill that's chunked) or decodes. If there's not
  133. enough memory, it can be preempted (for recompute) or swapped out.
  134. """
  135. # Selected sequences that are running and in a decoding phase.
  136. decode_seq_groups: List[SequenceGroup]
  137. # Selected sequences that are running and in a prefill phase.
  138. # I.e., it means the prefill has been chunked.
  139. prefill_seq_groups: List[SequenceGroup]
  140. # The preempted sequences.
  141. preempted: List[SequenceGroup]
  142. # Sequences that are swapped out.
  143. swapped_out: List[SequenceGroup]
  144. # The blocks to swap out.
  145. blocks_to_swap_out: Dict[int, int]
  146. # The blocks to copy.
  147. blocks_to_copy: List[Tuple[int, int]]
  148. # The number of slots for lookahead decoding.
  149. num_lookahead_slots: int
  150. @classmethod
  151. def create_empty(cls) -> "SchedulerRunningOutputs":
  152. return SchedulerRunningOutputs(
  153. decode_seq_groups=[],
  154. prefill_seq_groups=[],
  155. preempted=[],
  156. swapped_out=[],
  157. blocks_to_swap_out={},
  158. blocks_to_copy=[],
  159. num_lookahead_slots=0,
  160. )
  161. @dataclass
  162. class SchedulerSwappedInOutputs:
  163. """The requests that are scheduled from a swap queue.
  164. Could contain prefill (prefill that's chunked) or decodes.
  165. """
  166. # Selected sequences that are going to be swapped in and is in a
  167. # decoding phase.
  168. decode_seq_groups: List[SequenceGroup]
  169. # Selected sequences that are going to be swapped in and in a prefill
  170. # phase. I.e., it means the prefill has been chunked.
  171. prefill_seq_groups: List[SequenceGroup]
  172. # The blocks to swap in.
  173. blocks_to_swap_in: Dict[int, int]
  174. # The blocks to copy.
  175. blocks_to_copy: List[Tuple[int, int]]
  176. # The number of slots for lookahead decoding.
  177. num_lookahead_slots: int
  178. # Infeasible sequence groups.
  179. infeasible_seq_groups: List[SequenceGroup]
  180. @classmethod
  181. def create_empty(cls) -> "SchedulerSwappedInOutputs":
  182. return SchedulerSwappedInOutputs(
  183. decode_seq_groups=[],
  184. prefill_seq_groups=[],
  185. blocks_to_swap_in={},
  186. blocks_to_copy=[],
  187. num_lookahead_slots=0,
  188. infeasible_seq_groups=[],
  189. )
  190. @dataclass
  191. class SchedulerPrefillOutputs:
  192. """The requests that are scheduled from a waiting queue.
  193. Could contain a fresh prefill requests or preempted requests that need
  194. to be recomputed from scratch.
  195. """
  196. # Selected sequences for prefill.
  197. seq_groups: List[SequenceGroup]
  198. # Ignored sequence groups.
  199. ignored_seq_groups: List[SequenceGroup]
  200. num_lookahead_slots: int
  201. @classmethod
  202. def create_empty(cls) -> "SchedulerPrefillOutputs":
  203. return SchedulerPrefillOutputs(
  204. seq_groups=[],
  205. ignored_seq_groups=[],
  206. num_lookahead_slots=0,
  207. )
  208. class Scheduler:
  209. def __init__(
  210. self,
  211. scheduler_config: SchedulerConfig,
  212. cache_config: CacheConfig,
  213. lora_config: Optional[LoRAConfig],
  214. ) -> None:
  215. self.scheduler_config = scheduler_config
  216. self.cache_config = cache_config
  217. # Note for LoRA scheduling: the current policy is extremely
  218. # simple and NOT fair. It can lead to starvation of some
  219. # LoRAs. This should be improved in the future.
  220. self.lora_config = lora_config
  221. if self.scheduler_config.chunked_prefill_enabled:
  222. self.prompt_limit = self.scheduler_config.max_model_len
  223. else:
  224. self.prompt_limit = min(
  225. self.scheduler_config.max_model_len,
  226. self.scheduler_config.max_num_batched_tokens)
  227. BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
  228. version="v2" if self.scheduler_config.
  229. use_v2_block_manager else "v1")
  230. # Create the block space manager.
  231. self.block_manager = BlockSpaceManagerImpl(
  232. block_size=self.cache_config.block_size,
  233. num_gpu_blocks=self.cache_config.num_gpu_blocks,
  234. num_cpu_blocks=self.cache_config.num_cpu_blocks,
  235. sliding_window=self.cache_config.sliding_window,
  236. enable_caching=self.cache_config.enable_prefix_caching)
  237. # Sequence groups in the WAITING state.
  238. # Contain new prefill or preempted requests.
  239. self.waiting: Deque[SequenceGroup] = deque()
  240. # Sequence groups in the RUNNING state.
  241. # Contain decode requests.
  242. self.running: Deque[SequenceGroup] = deque()
  243. # Sequence groups in the SWAPPED state.
  244. # Contain decode requests that are swapped out.
  245. self.swapped: Deque[SequenceGroup] = deque()
  246. # Time at previous scheduling step
  247. self.prev_time = 0.0
  248. # Did we schedule a prompt at previous step?
  249. self.prev_prompt = False
  250. # Latency of the last prompt step
  251. self.last_prompt_latency = 0.0
  252. # The following field is test-only. It is used to inject artificial
  253. # preemption.
  254. self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
  255. self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
  256. if self.enable_artificial_preemption
  257. else 0)
  258. @property
  259. def lora_enabled(self) -> bool:
  260. return bool(self.lora_config)
  261. @property
  262. def num_decoding_tokens_per_seq(self) -> int:
  263. """The number of new tokens."""
  264. return 1
  265. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  266. # Add sequence groups to the waiting queue.
  267. self.waiting.append(seq_group)
  268. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  269. """Aborts a sequence group with the given ID.
  270. Check if the sequence group with the given ID
  271. is present in any of the state queue.
  272. If present, remove the sequence group from the state queue.
  273. Also, if any of the sequences in the sequence group is not finished,
  274. free the sequence with status `FINISHED_ABORTED`.
  275. Otherwise, do nothing.
  276. Args:
  277. request_id: The ID(s) of the sequence group to abort.
  278. """
  279. if isinstance(request_id, str):
  280. request_id = (request_id, )
  281. request_ids = set(request_id)
  282. for state_queue in [self.waiting, self.running, self.swapped]:
  283. aborted_groups: List[SequenceGroup] = []
  284. for seq_group in state_queue:
  285. if not request_ids:
  286. # Using 'break' here may add two extra iterations,
  287. # but is acceptable to reduce complexity.
  288. break
  289. if seq_group.request_id in request_ids:
  290. # Appending aborted group into pending list.
  291. aborted_groups.append(seq_group)
  292. request_ids.remove(seq_group.request_id)
  293. for aborted_group in aborted_groups:
  294. # Remove the sequence group from the state queue.
  295. state_queue.remove(aborted_group)
  296. for seq in aborted_group.get_seqs():
  297. if seq.is_finished():
  298. continue
  299. seq.status = SequenceStatus.FINISHED_ABORTED
  300. self.free_seq(seq)
  301. def has_unfinished_seqs(self) -> bool:
  302. return len(self.waiting) != 0 or len(self.running) != 0 or len(
  303. self.swapped) != 0
  304. def get_num_unfinished_seq_groups(self) -> int:
  305. return len(self.waiting) + len(self.running) + len(self.swapped)
  306. def _schedule_running(
  307. self,
  308. running_queue: deque,
  309. budget: SchedulingBudget,
  310. curr_loras: Optional[Set[int]],
  311. policy: Policy,
  312. enable_chunking: bool = False,
  313. ) -> Tuple[deque, SchedulerRunningOutputs]:
  314. """Schedule sequence groups that are running.
  315. Running queue should include decode and chunked prefill requests.
  316. Args:
  317. running_queue: The queue that contains running requests (i.e.,
  318. decodes). The given arguments are NOT in-place modified.
  319. budget: The scheduling budget. The argument is in-place updated
  320. when any decodes are preempted.
  321. curr_loras: Currently batched lora request ids. The argument is
  322. in-place updated when any decodes are preempted.
  323. policy: The sorting policy to sort running_queue.
  324. enable_chunking: If True, seq group can be chunked and only a
  325. chunked number of tokens are scheduled if
  326. `budget.num_batched_tokens` has not enough capacity to schedule
  327. all tokens.
  328. Returns:
  329. A tuple of remaining running queue (should be always 0) after
  330. scheduling and SchedulerRunningOutputs.
  331. """
  332. # Blocks that need to be swapped or copied before model execution.
  333. blocks_to_swap_out: Dict[int, int] = {}
  334. blocks_to_copy: List[Tuple[int, int]] = []
  335. decode_seq_groups: List[ScheduledSequenceGroup] = []
  336. prefill_seq_groups: List[ScheduledSequenceGroup] = []
  337. preempted: List[SequenceGroup] = []
  338. swapped_out: List[SequenceGroup] = []
  339. # NOTE(woosuk): Preemption happens only when there is no available slot
  340. # to keep all the sequence groups in the RUNNING state.
  341. # In this case, the policy is responsible for deciding which sequence
  342. # groups to preempt.
  343. now = time.time()
  344. running_queue = policy.sort_by_priority(now, running_queue)
  345. while running_queue:
  346. seq_group = running_queue[0]
  347. num_running_tokens = self._get_num_new_tokens(
  348. seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
  349. if num_running_tokens == 0:
  350. break
  351. running_queue.popleft()
  352. while not self._can_append_slots(seq_group):
  353. budget.subtract_num_batched_tokens(seq_group.request_id,
  354. num_running_tokens)
  355. num_running_seqs = seq_group.get_max_num_running_seqs()
  356. budget.subtract_num_seqs(seq_group.request_id,
  357. num_running_seqs)
  358. if curr_loras is not None and seq_group.lora_int_id > 0:
  359. curr_loras.remove(seq_group.lora_int_id)
  360. if running_queue:
  361. # Preempt the lowest-priority sequence groups.
  362. victim_seq_group = running_queue.pop()
  363. preempted_mode = self._preempt(victim_seq_group,
  364. blocks_to_swap_out)
  365. if preempted_mode == PreemptionMode.RECOMPUTE:
  366. preempted.append(victim_seq_group)
  367. else:
  368. swapped_out.append(victim_seq_group)
  369. else:
  370. # No other sequence groups can be preempted.
  371. # Preempt the current sequence group.
  372. preempted_mode = self._preempt(seq_group,
  373. blocks_to_swap_out)
  374. if preempted_mode == PreemptionMode.RECOMPUTE:
  375. preempted.append(seq_group)
  376. else:
  377. swapped_out.append(seq_group)
  378. break
  379. else:
  380. self._append_slots(seq_group, blocks_to_copy)
  381. is_prefill = seq_group.is_prefill()
  382. if is_prefill:
  383. prefill_seq_groups.append(
  384. ScheduledSequenceGroup(
  385. seq_group=seq_group,
  386. token_chunk_size=num_running_tokens))
  387. else:
  388. decode_seq_groups.append(
  389. ScheduledSequenceGroup(seq_group=seq_group,
  390. token_chunk_size=1))
  391. budget.add_num_batched_tokens(seq_group.request_id,
  392. num_running_tokens)
  393. # OPTIMIZATION: Note that get_max_num_running_seqs is
  394. # expensive. For the default scheduling chase where
  395. # enable_chunking is False, num_seqs are updated before running
  396. # this method, so we don't have to update it again here.
  397. if enable_chunking:
  398. num_running_seqs = seq_group.get_max_num_running_seqs()
  399. budget.add_num_seqs(seq_group.request_id, num_running_seqs)
  400. if curr_loras is not None and seq_group.lora_int_id > 0:
  401. curr_loras.add(seq_group.lora_int_id)
  402. return running_queue, SchedulerRunningOutputs(
  403. decode_seq_groups=decode_seq_groups,
  404. prefill_seq_groups=prefill_seq_groups,
  405. preempted=preempted,
  406. swapped_out=swapped_out,
  407. blocks_to_swap_out=blocks_to_swap_out,
  408. blocks_to_copy=blocks_to_copy,
  409. num_lookahead_slots=self._get_num_lookahead_slots(
  410. is_prefill=False))
  411. def _schedule_swapped(
  412. self,
  413. swapped_queue: deque,
  414. budget: SchedulingBudget,
  415. curr_loras: Optional[Set[int]],
  416. policy: Policy,
  417. enable_chunking: bool = False,
  418. ) -> Tuple[deque, SchedulerSwappedInOutputs]:
  419. """Schedule sequence groups that are swapped out.
  420. It schedules swapped requests as long as it fits `budget` and
  421. curr_loras <= max_lora from the scheduling config. The input arguments
  422. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  423. Args:
  424. swapped_queue: The queue that contains swapped out requests.
  425. The given arguments are NOT in-place modified.
  426. budget: The scheduling budget. The argument is in-place updated
  427. when any requests are swapped in.
  428. curr_loras: Currently batched lora request ids. The argument is
  429. in-place updated when any requests are swapped in.
  430. policy: The sorting policy to sort swapped_queue.
  431. enable_chunking: If True, seq group can be chunked and only a
  432. chunked number of tokens are scheduled if
  433. `budget.num_batched_tokens` has not enough capacity to schedule
  434. all tokens.
  435. Returns:
  436. A tuple of remaining swapped_queue after scheduling and
  437. SchedulerSwappedInOutputs.
  438. """
  439. # Blocks that need to be swapped or copied before model execution.
  440. blocks_to_swap_in: Dict[int, int] = {}
  441. blocks_to_copy: List[Tuple[int, int]] = []
  442. decode_seq_groups: List[ScheduledSequenceGroup] = []
  443. prefill_seq_groups: List[ScheduledSequenceGroup] = []
  444. now = time.time()
  445. swapped_queue = policy.sort_by_priority(now, swapped_queue)
  446. infeasible_seq_groups: List[SequenceGroup] = []
  447. leftover_swapped: Deque[SequenceGroup] = deque()
  448. while swapped_queue:
  449. seq_group = swapped_queue[0]
  450. # If the sequence group cannot be swapped in, stop.
  451. alloc_status = self.block_manager.can_swap_in(seq_group)
  452. if alloc_status == AllocStatus.LATER:
  453. break
  454. elif alloc_status == AllocStatus.NEVER:
  455. logger.warning(
  456. "Failing the request %s because there's not enough kv "
  457. "cache blocks to run the entire sequence.",
  458. seq_group.request_id)
  459. for seq in seq_group.get_seqs():
  460. seq.status = SequenceStatus.FINISHED_IGNORED
  461. infeasible_seq_groups.append(seq_group)
  462. swapped_queue.popleft()
  463. continue
  464. lora_int_id = 0
  465. if self.lora_enabled:
  466. lora_int_id = seq_group.lora_int_id
  467. assert curr_loras is not None
  468. assert self.lora_config is not None
  469. if (lora_int_id > 0 and (lora_int_id not in curr_loras)
  470. and len(curr_loras) >= self.lora_config.max_loras):
  471. # We don't have a space for another LoRA, so
  472. # we ignore this request for now.
  473. leftover_swapped.appendleft(seq_group)
  474. swapped_queue.popleft()
  475. continue
  476. # The total number of sequences in the RUNNING state should not
  477. # exceed the maximum number of sequences.
  478. num_new_seqs = seq_group.get_max_num_running_seqs()
  479. num_new_tokens = self._get_num_new_tokens(seq_group,
  480. SequenceStatus.SWAPPED,
  481. enable_chunking, budget)
  482. if (num_new_tokens == 0
  483. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  484. num_new_seqs=num_new_seqs)):
  485. break
  486. if lora_int_id > 0 and curr_loras is not None:
  487. curr_loras.add(lora_int_id)
  488. swapped_queue.popleft()
  489. self._swap_in(seq_group, blocks_to_swap_in)
  490. self._append_slots(seq_group, blocks_to_copy)
  491. is_prefill = seq_group.is_prefill()
  492. if is_prefill:
  493. prefill_seq_groups.append(
  494. ScheduledSequenceGroup(seq_group,
  495. token_chunk_size=num_new_tokens))
  496. else:
  497. decode_seq_groups.append(
  498. ScheduledSequenceGroup(seq_group, token_chunk_size=1))
  499. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  500. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  501. swapped_queue.extendleft(leftover_swapped)
  502. return swapped_queue, SchedulerSwappedInOutputs(
  503. decode_seq_groups=decode_seq_groups,
  504. prefill_seq_groups=prefill_seq_groups,
  505. blocks_to_swap_in=blocks_to_swap_in,
  506. blocks_to_copy=blocks_to_copy,
  507. num_lookahead_slots=self._get_num_lookahead_slots(
  508. is_prefill=False),
  509. infeasible_seq_groups=infeasible_seq_groups,
  510. )
  511. def _schedule_prefills(
  512. self,
  513. waiting_queue: deque,
  514. budget: SchedulingBudget,
  515. curr_loras: Optional[Set[int]],
  516. enable_chunking: bool = False,
  517. ) -> Tuple[deque, SchedulerPrefillOutputs]:
  518. """Schedule sequence groups that are in prefill stage.
  519. Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
  520. as a new prefill (that starts from beginning -> most recently generated
  521. tokens).
  522. It schedules waiting requests as long as it fits `budget` and
  523. curr_loras <= max_lora from the scheduling config. The input arguments
  524. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  525. Args:
  526. waiting_queue: The queue that contains prefill requests.
  527. The given arguments are NOT in-place modified.
  528. budget: The scheduling budget. The argument is in-place updated
  529. when any requests are scheduled.
  530. curr_loras: Currently batched lora request ids. The argument is
  531. in-place updated when any requests are scheduled.
  532. enable_chunking: If True, seq group can be chunked and only a
  533. chunked number of tokens are scheduled if
  534. `budget.num_batched_tokens` has not enough capacity to schedule
  535. all tokens.
  536. Returns:
  537. A tuple of remaining waiting_queue after scheduling and
  538. SchedulerSwappedInOutputs.
  539. """
  540. ignored_seq_groups: List[SequenceGroup] = []
  541. seq_groups: List[SequenceGroup] = []
  542. # We don't sort waiting queue because we assume it is sorted.
  543. # Copy the queue so that the input queue is not modified.
  544. waiting_queue = deque([s for s in waiting_queue])
  545. leftover_waiting_sequences: Deque[SequenceGroup] = deque()
  546. while self._passed_delay(time.time()) and waiting_queue:
  547. seq_group = waiting_queue[0]
  548. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  549. assert len(waiting_seqs) == 1, (
  550. "Waiting sequence group should have only one prompt "
  551. "sequence.")
  552. num_new_tokens = self._get_num_new_tokens(seq_group,
  553. SequenceStatus.WAITING,
  554. enable_chunking, budget)
  555. if not enable_chunking:
  556. num_prompt_tokens = waiting_seqs[0].get_len()
  557. assert num_new_tokens == num_prompt_tokens
  558. if num_new_tokens > self.prompt_limit:
  559. logger.warning(
  560. "Input prompt (%d tokens) is too long"
  561. " and exceeds limit of %d", num_new_tokens,
  562. self.prompt_limit)
  563. for seq in waiting_seqs:
  564. seq.status = SequenceStatus.FINISHED_IGNORED
  565. ignored_seq_groups.append(seq_group)
  566. waiting_queue.popleft()
  567. continue
  568. # If the sequence group cannot be allocated, stop.
  569. can_allocate = self.block_manager.can_allocate(seq_group)
  570. if can_allocate == AllocStatus.LATER:
  571. break
  572. elif can_allocate == AllocStatus.NEVER:
  573. logger.warning(
  574. "Input prompt (%d tokens) is too long"
  575. " and exceeds the capacity of block_manager",
  576. num_new_tokens)
  577. for seq in waiting_seqs:
  578. seq.status = SequenceStatus.FINISHED_IGNORED
  579. ignored_seq_groups.append(seq_group)
  580. waiting_queue.popleft()
  581. continue
  582. lora_int_id = 0
  583. if self.lora_enabled:
  584. lora_int_id = seq_group.lora_int_id
  585. assert curr_loras is not None
  586. assert self.lora_config is not None
  587. if (self.lora_enabled and lora_int_id > 0
  588. and lora_int_id not in curr_loras
  589. and len(curr_loras) >= self.lora_config.max_loras):
  590. # We don't have a space for another LoRA, so
  591. # we ignore this request for now.
  592. leftover_waiting_sequences.appendleft(seq_group)
  593. waiting_queue.popleft()
  594. continue
  595. num_new_seqs = seq_group.get_max_num_running_seqs()
  596. if (num_new_tokens == 0
  597. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  598. num_new_seqs=num_new_seqs)):
  599. break
  600. # Can schedule this request.
  601. if curr_loras is not None and lora_int_id > 0:
  602. curr_loras.add(lora_int_id)
  603. waiting_queue.popleft()
  604. self._allocate_and_set_running(seq_group)
  605. seq_groups.append(
  606. ScheduledSequenceGroup(seq_group=seq_group,
  607. token_chunk_size=num_new_tokens))
  608. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  609. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  610. # Queue requests that couldn't be scheduled.
  611. waiting_queue.extendleft(leftover_waiting_sequences)
  612. if len(seq_groups) > 0:
  613. self.prev_prompt = True
  614. return waiting_queue, SchedulerPrefillOutputs(
  615. seq_groups=seq_groups,
  616. ignored_seq_groups=ignored_seq_groups,
  617. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
  618. def _schedule_default(self) -> SchedulerOutputs:
  619. """Schedule queued requests.
  620. The current policy is designed to optimize the throughput. First,
  621. it batches as many prefill requests as possible. And it schedules
  622. decodes. If there's a pressure on GPU memory, decode requests can
  623. be swapped or preempted.
  624. """
  625. # Include running requests to the budget.
  626. budget = SchedulingBudget(
  627. token_budget=self.scheduler_config.max_num_batched_tokens,
  628. max_num_seqs=self.scheduler_config.max_num_seqs,
  629. )
  630. # Make sure we include num running seqs before scheduling prefill,
  631. # so that we don't schedule beyond max_num_seqs for prefill.
  632. for seq_group in self.running:
  633. budget.add_num_seqs(seq_group.request_id,
  634. seq_group.get_max_num_running_seqs())
  635. curr_loras = set(
  636. seq_group.lora_int_id
  637. for seq_group in self.running) if self.lora_enabled else None
  638. remaining_waiting, prefills = (self.waiting,
  639. SchedulerPrefillOutputs.create_empty())
  640. remaining_running, running_scheduled = (
  641. self.running, SchedulerRunningOutputs.create_empty())
  642. remaining_swapped, swapped_in = (
  643. self.swapped, SchedulerSwappedInOutputs.create_empty())
  644. # If any requests are swapped, prioritized swapped requests.
  645. if not self.swapped:
  646. remaining_waiting, prefills = self._schedule_prefills(
  647. self.waiting, budget, curr_loras, enable_chunking=False)
  648. fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
  649. # Don't schedule decodes if prefills are scheduled.
  650. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
  651. # only contains decode requests, not chunked prefills.
  652. if len(prefills.seq_groups) == 0:
  653. remaining_running, running_scheduled = self._schedule_running(
  654. self.running,
  655. budget,
  656. curr_loras,
  657. fcfs_policy,
  658. enable_chunking=False)
  659. # If any sequence group is preempted, do not swap in any sequence
  660. # group. because it means there's no slot for new running requests.
  661. if len(running_scheduled.preempted) + len(
  662. running_scheduled.swapped_out) == 0:
  663. remaining_swapped, swapped_in = self._schedule_swapped(
  664. self.swapped, budget, curr_loras, fcfs_policy)
  665. assert (budget.num_batched_tokens <=
  666. self.scheduler_config.max_num_batched_tokens)
  667. assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
  668. # Update waiting requests.
  669. self.waiting = remaining_waiting
  670. self.waiting.extendleft(running_scheduled.preempted)
  671. # Update new running requests.
  672. self.running = remaining_running
  673. self.running.extend([s.seq_group for s in prefills.seq_groups])
  674. self.running.extend(
  675. [s.seq_group for s in running_scheduled.decode_seq_groups])
  676. self.running.extend(
  677. [s.seq_group for s in swapped_in.decode_seq_groups])
  678. # Update swapped requests.
  679. self.swapped = remaining_swapped
  680. self.swapped.extend(running_scheduled.swapped_out)
  681. # There should be no prefill from running queue because this policy
  682. # doesn't allow chunked prefills.
  683. assert len(running_scheduled.prefill_seq_groups) == 0
  684. assert len(swapped_in.prefill_seq_groups) == 0
  685. return SchedulerOutputs(
  686. scheduled_seq_groups=(prefills.seq_groups +
  687. running_scheduled.decode_seq_groups +
  688. swapped_in.decode_seq_groups),
  689. num_prefill_groups=len(prefills.seq_groups),
  690. num_batched_tokens=budget.num_batched_tokens,
  691. blocks_to_swap_in=swapped_in.blocks_to_swap_in,
  692. blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
  693. blocks_to_copy=running_scheduled.blocks_to_copy +
  694. swapped_in.blocks_to_copy,
  695. ignored_seq_groups=prefills.ignored_seq_groups +
  696. swapped_in.infeasible_seq_groups,
  697. num_lookahead_slots=running_scheduled.num_lookahead_slots,
  698. running_queue_size=len(self.running),
  699. )
  700. def _schedule_chunked_prefill(self):
  701. """Schedule queued requests.
  702. Chunked prefill allows to chunk prefill requests, batch them together
  703. with decode requests. This policy 1. schedule as many decoding requests
  704. as possible. 2. schedule chunked prefill requests that are not
  705. finished. 3. schedule swapped request. 4. schedule new prefill
  706. requests.
  707. The policy can sustain the high GPU utilization because it can put
  708. prefill and decodes requests to the same batch, while it improves
  709. inter token latency because decodes requests don't need to blocked
  710. by prefill requests.
  711. """
  712. budget = SchedulingBudget(
  713. token_budget=self.scheduler_config.max_num_batched_tokens,
  714. max_num_seqs=self.scheduler_config.max_num_seqs,
  715. )
  716. curr_loras: Set[int] = set()
  717. remaining_waiting, prefills = (self.waiting,
  718. SchedulerPrefillOutputs.create_empty())
  719. remaining_running, running_scheduled = (
  720. self.running, SchedulerRunningOutputs.create_empty())
  721. remaining_swapped, swapped_in = (
  722. self.swapped, SchedulerSwappedInOutputs.create_empty())
  723. # Decoding should be always scheduled first by fcfs.
  724. fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
  725. remaining_running, running_scheduled = self._schedule_running(
  726. self.running,
  727. budget,
  728. curr_loras,
  729. fcfs_policy,
  730. enable_chunking=True)
  731. # Schedule swapped out requests.
  732. # If preemption happens, it means we don't have space for swap-in.
  733. if len(running_scheduled.preempted) + len(
  734. running_scheduled.swapped_out) == 0:
  735. remaining_swapped, swapped_in = self._schedule_swapped(
  736. self.swapped, budget, curr_loras, fcfs_policy)
  737. # Schedule new prefills.
  738. remaining_waiting, prefills = self._schedule_prefills(
  739. self.waiting, budget, curr_loras, enable_chunking=True)
  740. assert (budget.num_batched_tokens <=
  741. self.scheduler_config.max_num_batched_tokens)
  742. assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
  743. # Update waiting requests.
  744. self.waiting = remaining_waiting
  745. self.waiting.extendleft(running_scheduled.preempted)
  746. # Update new running requests.
  747. self.running = remaining_running
  748. self.running.extend([s.seq_group for s in prefills.seq_groups])
  749. self.running.extend(
  750. [s.seq_group for s in running_scheduled.decode_seq_groups])
  751. self.running.extend(
  752. [s.seq_group for s in running_scheduled.prefill_seq_groups])
  753. self.running.extend(
  754. [s.seq_group for s in swapped_in.decode_seq_groups])
  755. self.running.extend(
  756. [s.seq_group for s in swapped_in.prefill_seq_groups])
  757. # Update swapped requests.
  758. self.swapped = remaining_swapped
  759. self.swapped.extend(running_scheduled.swapped_out)
  760. return SchedulerOutputs(
  761. scheduled_seq_groups=(prefills.seq_groups +
  762. running_scheduled.prefill_seq_groups +
  763. swapped_in.prefill_seq_groups +
  764. running_scheduled.decode_seq_groups +
  765. swapped_in.decode_seq_groups),
  766. num_prefill_groups=(len(prefills.seq_groups) +
  767. len(swapped_in.prefill_seq_groups) +
  768. len(running_scheduled.prefill_seq_groups)),
  769. num_batched_tokens=budget.num_batched_tokens,
  770. blocks_to_swap_in=swapped_in.blocks_to_swap_in,
  771. blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
  772. blocks_to_copy=running_scheduled.blocks_to_copy +
  773. swapped_in.blocks_to_copy,
  774. ignored_seq_groups=prefills.ignored_seq_groups,
  775. num_lookahead_slots=running_scheduled.num_lookahead_slots,
  776. running_queue_size=len(self.running),
  777. )
  778. def _schedule(self) -> SchedulerOutputs:
  779. """Schedule queued requests."""
  780. if self.scheduler_config.chunked_prefill_enabled:
  781. return self._schedule_chunked_prefill()
  782. else:
  783. return self._schedule_default()
  784. def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
  785. """Determine whether or not we have enough space in the KV cache to
  786. continue generation of the sequence group.
  787. """
  788. # It is True only for testing case to trigger artificial preemption.
  789. if (self.enable_artificial_preemption
  790. and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
  791. and self.artificial_preempt_cnt > 0):
  792. self.artificial_preempt_cnt -= 1
  793. return False
  794. # Appending slots only occurs in decoding.
  795. is_prefill = False
  796. return self.block_manager.can_append_slots(
  797. seq_group=seq_group,
  798. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  799. )
  800. def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
  801. # Schedule sequence groups.
  802. # This function call changes the internal states of the scheduler
  803. # such as self.running, self.swapped, and self.waiting.
  804. scheduler_outputs = self._schedule()
  805. now = time.time()
  806. # Create input data structures.
  807. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  808. for i, scheduled_seq_group in enumerate(
  809. scheduler_outputs.scheduled_seq_groups):
  810. seq_group = scheduled_seq_group.seq_group
  811. token_chunk_size = scheduled_seq_group.token_chunk_size
  812. seq_group.maybe_set_first_scheduled_time(now)
  813. # seq_id -> SequenceData
  814. seq_data: Dict[int, SequenceData] = {}
  815. # seq_id -> physical block numbers
  816. block_tables: Dict[int, List[int]] = {}
  817. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  818. seq_id = seq.seq_id
  819. seq_data[seq_id] = seq.data
  820. block_tables[seq_id] = self.block_manager.get_block_table(seq)
  821. self.block_manager.access_all_blocks_in_seq(seq, now)
  822. common_computed_block_nums = (
  823. self.block_manager.get_common_computed_block_ids(
  824. seq_group.get_seqs(status=SequenceStatus.RUNNING)))
  825. do_sample = True
  826. if seq_group.is_prefill():
  827. seqs = seq_group.get_seqs()
  828. # Prefill has only 1 sequence.
  829. assert len(seqs) == 1
  830. # In the next iteration, all prompt tokens are not computed.
  831. # It means the prefill is chunked, and we don't need sampling.
  832. # NOTE: We use get_len instead of get_prompt_len because when
  833. # a sequence is preempted, prefill includes previous generated
  834. # output tokens.
  835. if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
  836. seqs[0].data.get_len()):
  837. do_sample = False
  838. # It assumes the scheduled_seq_groups is ordered by
  839. # prefill < decoding.
  840. is_prompt = seq_group.is_prefill()
  841. seq_group_metadata = SequenceGroupMetadata(
  842. request_id=seq_group.request_id,
  843. is_prompt=is_prompt,
  844. seq_data=seq_data,
  845. sampling_params=seq_group.sampling_params,
  846. block_tables=block_tables,
  847. do_sample=do_sample,
  848. token_chunk_size=token_chunk_size,
  849. lora_request=seq_group.lora_request,
  850. computed_block_nums=common_computed_block_nums,
  851. state=seq_group.state,
  852. # `multi_modal_data` will only be present for the 1st comm
  853. # between engine and worker.
  854. # the subsequent comms can still use delta, but
  855. # `multi_modal_data` will be None.
  856. multi_modal_data=seq_group.multi_modal_data
  857. if scheduler_outputs.num_prefill_groups > 0 else None,
  858. )
  859. seq_group_metadata_list.append(seq_group_metadata)
  860. # Now that the batch has been created, we can assume all blocks in the
  861. # batch will have been computed before the next scheduling invocation.
  862. # This is because the engine assumes that a failure in model execution
  863. # will crash the vLLM instance / will not retry.
  864. for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
  865. self.block_manager.mark_blocks_as_computed(
  866. scheduled_seq_group.seq_group)
  867. return seq_group_metadata_list, scheduler_outputs
  868. def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  869. self.block_manager.fork(parent_seq, child_seq)
  870. def free_seq(self, seq: Sequence) -> None:
  871. """Free a sequence from a block table."""
  872. self.block_manager.free(seq)
  873. def free_finished_seq_groups(self) -> None:
  874. self.running = deque(seq_group for seq_group in self.running
  875. if not seq_group.is_finished())
  876. def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
  877. self.block_manager.allocate(seq_group)
  878. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
  879. seq.status = SequenceStatus.RUNNING
  880. def _append_slots(
  881. self,
  882. seq_group: SequenceGroup,
  883. blocks_to_copy: List[Tuple[int, int]],
  884. ) -> None:
  885. """Appends new slots to the sequences in the given sequence group.
  886. Args:
  887. seq_group (SequenceGroup): The sequence group containing the
  888. sequences to append slots to.
  889. blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
  890. ints, the first int is the source block index, and the second
  891. int is the destination block index. This list is updated with
  892. the new source and destination block indices for the appended
  893. slots.
  894. """
  895. num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
  896. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  897. cows = self.block_manager.append_slots(seq, num_lookahead_slots)
  898. blocks_to_copy.extend(cows)
  899. def _preempt(
  900. self,
  901. seq_group: SequenceGroup,
  902. blocks_to_swap_out: Dict[int, int],
  903. preemption_mode: Optional[PreemptionMode] = None,
  904. ) -> PreemptionMode:
  905. # If preemption mode is not specified, we determine the mode as follows:
  906. # We use recomputation by default since it incurs lower overhead than
  907. # swapping. However, when the sequence group has multiple sequences
  908. # (e.g., beam search), recomputation is not currently supported. In
  909. # such a case, we use swapping instead.
  910. # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
  911. # As swapped sequences are prioritized over waiting sequences,
  912. # sequence groups with multiple sequences are implicitly prioritized
  913. # over sequence groups with a single sequence.
  914. # TODO(woosuk): Support recomputation for sequence groups with multiple
  915. # sequences. This may require a more sophisticated CUDA kernel.
  916. if preemption_mode is None:
  917. if seq_group.get_max_num_running_seqs() == 1:
  918. preemption_mode = PreemptionMode.RECOMPUTE
  919. else:
  920. preemption_mode = PreemptionMode.SWAP
  921. if preemption_mode == PreemptionMode.RECOMPUTE:
  922. self._preempt_by_recompute(seq_group)
  923. elif preemption_mode == PreemptionMode.SWAP:
  924. self._preempt_by_swap(seq_group, blocks_to_swap_out)
  925. else:
  926. raise AssertionError("Invalid preemption mode.")
  927. return preemption_mode
  928. def _preempt_by_recompute(
  929. self,
  930. seq_group: SequenceGroup,
  931. ) -> None:
  932. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  933. assert len(seqs) == 1
  934. for seq in seqs:
  935. seq.status = SequenceStatus.WAITING
  936. self.free_seq(seq)
  937. seq.reset_state_for_recompute()
  938. def _preempt_by_swap(
  939. self,
  940. seq_group: SequenceGroup,
  941. blocks_to_swap_out: Dict[int, int],
  942. ) -> None:
  943. self._swap_out(seq_group, blocks_to_swap_out)
  944. def _swap_in(
  945. self,
  946. seq_group: SequenceGroup,
  947. blocks_to_swap_in: Dict[int, int],
  948. ) -> None:
  949. mapping = self.block_manager.swap_in(seq_group)
  950. blocks_to_swap_in.update(mapping)
  951. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  952. seq.status = SequenceStatus.RUNNING
  953. def _swap_out(
  954. self,
  955. seq_group: SequenceGroup,
  956. blocks_to_swap_out: Dict[int, int],
  957. ) -> None:
  958. if not self.block_manager.can_swap_out(seq_group):
  959. # FIXME(woosuk): Abort the sequence group instead of aborting the
  960. # entire engine.
  961. raise RuntimeError(
  962. "Aborted due to the lack of CPU swap space. Please increase "
  963. "the swap space to avoid this error.")
  964. mapping = self.block_manager.swap_out(seq_group)
  965. blocks_to_swap_out.update(mapping)
  966. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  967. seq.status = SequenceStatus.SWAPPED
  968. def _passed_delay(self, now: float) -> bool:
  969. if self.prev_prompt:
  970. self.last_prompt_latency = now - self.prev_time
  971. self.prev_time, self.prev_prompt = now, False
  972. # Delay scheduling prompts to let waiting queue fill up
  973. if self.scheduler_config.delay_factor > 0 and self.waiting:
  974. earliest_arrival_time = min(
  975. [e.metrics.arrival_time for e in self.waiting])
  976. passed_delay = (
  977. (now - earliest_arrival_time) >
  978. (self.scheduler_config.delay_factor * self.last_prompt_latency)
  979. or not self.running)
  980. else:
  981. passed_delay = True
  982. return passed_delay
  983. def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
  984. """The number of slots to allocate per sequence per step, beyond known
  985. token ids. Speculative decoding uses these slots to store KV activations
  986. of tokens which may or may not be accepted.
  987. Speculative decoding does not yet support prefill, so we do not perform
  988. lookahead allocation for prefill.
  989. """
  990. if is_prefill:
  991. return 0
  992. return self.scheduler_config.num_lookahead_slots
  993. def _get_num_new_tokens(self, seq_group: SequenceGroup,
  994. status: SequenceStatus, enable_chunking: bool,
  995. budget: SchedulingBudget) -> int:
  996. """Get the next new tokens to compute for a given sequence group
  997. that's in a given `status`.
  998. The API could chunk the number of tokens to compute based on `budget`
  999. if `enable_chunking` is True. If a sequence group has multiple
  1000. sequences (e.g., running beam search), it means it is in decoding
  1001. phase, so chunking doesn't happen.
  1002. Returns 0 if the new token cannot be computed due to token budget.
  1003. """
  1004. num_new_tokens = 0
  1005. seqs = seq_group.get_seqs(status=status)
  1006. for seq in seqs:
  1007. num_new_tokens += seq.get_num_new_tokens()
  1008. assert num_new_tokens > 0
  1009. # Chunk if a running request cannot fit in.
  1010. # If number of seq > 1, it means it is doing beam search in a
  1011. # decode phase. Do not chunk in that case.
  1012. if enable_chunking and len(seqs) == 1:
  1013. num_new_tokens = min(num_new_tokens,
  1014. budget.remaining_token_budget())
  1015. return num_new_tokens