scheduler.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148
  1. import enum
  2. import os
  3. import random
  4. import time
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
  8. from loguru import logger
  9. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  10. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  11. SequenceGroupMetadata, SequenceStatus)
  12. from aphrodite.common.utils import merge_dicts
  13. from aphrodite.lora.request import LoRARequest
  14. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  15. from aphrodite.processing.policy import Policy, PolicyFactory
  16. # Test-only. If configured, decode is preempted with
  17. # ARTIFICIAL_PREEMPTION_PROB% probability.
  18. ENABLE_ARTIFICIAL_PREEMPT = bool(
  19. os.getenv("APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
  20. ARTIFICIAL_PREEMPTION_PROB = 0.5
  21. ARTIFICIAL_PREEMPTION_MAX_CNT = 500
  22. class PreemptionMode(enum.Enum):
  23. """Preemption modes.
  24. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  25. and swap them back in when the sequences are resumed.
  26. 2. Recomputation: Discard the blocks of the preempted sequences and
  27. recompute them when the sequences are resumed, treating the sequences as
  28. new prompts.
  29. """
  30. SWAP = enum.auto()
  31. RECOMPUTE = enum.auto()
  32. @dataclass
  33. class SchedulingBudget:
  34. """The available slots for scheduling.
  35. TODO: Right now, the budget is request_id-aware meaning it can ignore
  36. budget update from the same request_id. It is because in normal scheduling
  37. path, we update RUNNING num_seqs ahead of time, meaning it could be
  38. updated more than once when scheduling RUNNING requests. Since this won't
  39. happen if we only have chunked prefill scheduling, we can remove this
  40. feature from the API when chunked prefill is enabled by default.
  41. """
  42. token_budget: int
  43. max_num_seqs: int
  44. _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
  45. _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
  46. _num_batched_tokens: int = 0
  47. _num_curr_seqs: int = 0
  48. def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
  49. assert num_new_tokens != 0
  50. assert num_new_seqs != 0
  51. return (self.num_batched_tokens + num_new_tokens <= self.token_budget
  52. and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
  53. def remaining_token_budget(self):
  54. return self.token_budget - self.num_batched_tokens
  55. def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
  56. if req_id in self._requeset_ids_num_batched_tokens:
  57. return
  58. self._requeset_ids_num_batched_tokens.add(req_id)
  59. self._num_batched_tokens += num_batched_tokens
  60. def subtract_num_batched_tokens(self, req_id: str,
  61. num_batched_tokens: int):
  62. if req_id in self._requeset_ids_num_batched_tokens:
  63. self._requeset_ids_num_batched_tokens.remove(req_id)
  64. self._num_batched_tokens -= num_batched_tokens
  65. def add_num_seqs(self, req_id: str, num_curr_seqs: int):
  66. if req_id in self._requeset_ids_num_curr_seqs:
  67. return
  68. self._requeset_ids_num_curr_seqs.add(req_id)
  69. self._num_curr_seqs += num_curr_seqs
  70. def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
  71. if req_id in self._requeset_ids_num_curr_seqs:
  72. self._requeset_ids_num_curr_seqs.remove(req_id)
  73. self._num_curr_seqs -= num_curr_seqs
  74. @property
  75. def num_batched_tokens(self):
  76. return self._num_batched_tokens
  77. @property
  78. def num_curr_seqs(self):
  79. return self._num_curr_seqs
  80. @dataclass
  81. class ScheduledSequenceGroup:
  82. # A sequence group that's scheduled.
  83. seq_group: SequenceGroup
  84. # The total chunk size (number of tokens) to process for next iteration.
  85. # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
  86. # chunked, it can be smaller than that.
  87. token_chunk_size: int
  88. @dataclass
  89. class SchedulerOutputs:
  90. """The scheduling decision made from a scheduler."""
  91. # Scheduled sequence groups.
  92. scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
  93. # Number of prefill groups scheduled.
  94. num_prefill_groups: int
  95. # Total number of batched tokens.
  96. num_batched_tokens: int
  97. # Blocks to swap in. Dict of CPU -> GPU block number.
  98. blocks_to_swap_in: Dict[int, int]
  99. # Blocks to swap out. Dict of GPU -> CPU block number.
  100. blocks_to_swap_out: Dict[int, int]
  101. # Blocks to copy. Source to a list of dest blocks.
  102. blocks_to_copy: Dict[int, List[int]]
  103. # Sequence groups that are going to be ignored.
  104. ignored_seq_groups: List[SequenceGroup]
  105. # The number of slots for lookahead decoding.
  106. num_lookahead_slots: int
  107. def __post_init__(self):
  108. # Swap in and swap out should never happen at the same time.
  109. assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
  110. self.num_loras: int = len(self.lora_requests)
  111. if self.num_loras > 0:
  112. self._sort_by_lora_ids()
  113. def is_empty(self) -> bool:
  114. # NOTE: We do not consider the ignored sequence groups.
  115. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  116. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  117. def _sort_by_lora_ids(self):
  118. self.scheduled_seq_groups = sorted(
  119. self.scheduled_seq_groups,
  120. key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
  121. @property
  122. def lora_requests(self) -> Set[LoRARequest]:
  123. return {
  124. g.seq_group.lora_request
  125. for g in self.scheduled_seq_groups
  126. if g.seq_group.lora_request is not None
  127. }
  128. @dataclass
  129. class SchedulerRunningOutputs:
  130. """The requests that are scheduled from a running queue.
  131. Could contain prefill (prefill that's chunked) or decodes. If there's not
  132. enough memory, it can be preempted (for recompute) or swapped out.
  133. """
  134. # Selected sequences that are running and in a decoding phase.
  135. decode_seq_groups: List[SequenceGroup]
  136. # Selected sequences that are running and in a prefill phase.
  137. # I.e., it means the prefill has been chunked.
  138. prefill_seq_groups: List[SequenceGroup]
  139. # The preempted sequences.
  140. preempted: List[SequenceGroup]
  141. # Sequences that are swapped out.
  142. swapped_out: List[SequenceGroup]
  143. # The blocks to swap out.
  144. blocks_to_swap_out: Dict[int, int]
  145. # The blocks to copy.
  146. blocks_to_copy: Dict[int, List[int]]
  147. # The number of slots for lookahead decoding.
  148. num_lookahead_slots: int
  149. @classmethod
  150. def create_empty(cls) -> "SchedulerRunningOutputs":
  151. return SchedulerRunningOutputs(
  152. decode_seq_groups=[],
  153. prefill_seq_groups=[],
  154. preempted=[],
  155. swapped_out=[],
  156. blocks_to_swap_out={},
  157. blocks_to_copy={},
  158. num_lookahead_slots=0,
  159. )
  160. @dataclass
  161. class SchedulerSwappedInOutputs:
  162. """The requests that are scheduled from a swap queue.
  163. Could contain prefill (prefill that's chunked) or decodes.
  164. """
  165. # Selected sequences that are going to be swapped in and is in a
  166. # decoding phase.
  167. decode_seq_groups: List[SequenceGroup]
  168. # Selected sequences that are going to be swapped in and in a prefill
  169. # phase. I.e., it means the prefill has been chunked.
  170. prefill_seq_groups: List[SequenceGroup]
  171. # The blocks to swap in.
  172. blocks_to_swap_in: Dict[int, int]
  173. # The blocks to copy.
  174. blocks_to_copy: Dict[int, List[int]]
  175. # The number of slots for lookahead decoding.
  176. num_lookahead_slots: int
  177. @classmethod
  178. def create_empty(cls) -> "SchedulerSwappedInOutputs":
  179. return SchedulerSwappedInOutputs(
  180. decode_seq_groups=[],
  181. prefill_seq_groups=[],
  182. blocks_to_swap_in={},
  183. blocks_to_copy={},
  184. num_lookahead_slots=0,
  185. )
  186. @dataclass
  187. class SchedulerPrefillOutputs:
  188. """The requests that are scheduled from a waiting queue.
  189. Could contain a fresh prefill requests or preempted requests that need
  190. to be recomputed from scratch.
  191. """
  192. # Selected sequences for prefill.
  193. seq_groups: List[SequenceGroup]
  194. # Ignored sequence groups.
  195. ignored_seq_groups: List[SequenceGroup]
  196. num_lookahead_slots: int
  197. @classmethod
  198. def create_empty(cls) -> "SchedulerPrefillOutputs":
  199. return SchedulerPrefillOutputs(
  200. seq_groups=[],
  201. ignored_seq_groups=[],
  202. num_lookahead_slots=0,
  203. )
  204. class Scheduler:
  205. def __init__(
  206. self,
  207. scheduler_config: SchedulerConfig,
  208. cache_config: CacheConfig,
  209. lora_config: Optional[LoRAConfig],
  210. ) -> None:
  211. self.scheduler_config = scheduler_config
  212. self.cache_config = cache_config
  213. # Note for LoRA scheduling: the current policy is extremely
  214. # simple and NOT fair. It can lead to starvation of some
  215. # LoRAs. This should be improved in the future.
  216. self.lora_config = lora_config
  217. if self.scheduler_config.chunked_prefill_enabled:
  218. self.prompt_limit = self.scheduler_config.max_model_len
  219. else:
  220. self.prompt_limit = min(
  221. self.scheduler_config.max_model_len,
  222. self.scheduler_config.max_num_batched_tokens)
  223. BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
  224. version="v2" if self.scheduler_config.
  225. use_v2_block_manager else "v1")
  226. # Create the block space manager.
  227. self.block_manager = BlockSpaceManagerImpl(
  228. block_size=self.cache_config.block_size,
  229. num_gpu_blocks=self.cache_config.num_gpu_blocks,
  230. num_cpu_blocks=self.cache_config.num_cpu_blocks,
  231. sliding_window=self.cache_config.sliding_window,
  232. enable_caching=self.cache_config.enable_prefix_caching)
  233. # Sequence groups in the WAITING state.
  234. # Contain new prefill or preempted requests.
  235. self.waiting: Deque[SequenceGroup] = deque()
  236. # Sequence groups in the RUNNING state.
  237. # Contain decode requests.
  238. self.running: Deque[SequenceGroup] = deque()
  239. # Sequence groups in the SWAPPED state.
  240. # Contain decode requests that are swapped out.
  241. self.swapped: Deque[SequenceGroup] = deque()
  242. # Time at previous scheduling step
  243. self.prev_time = 0.0
  244. # Did we schedule a prompt at previous step?
  245. self.prev_prompt = False
  246. # Latency of the last prompt step
  247. self.last_prompt_latency = 0.0
  248. # The following field is test-only. It is used to inject artificial
  249. # preemption.
  250. self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
  251. self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
  252. if self.enable_artificial_preemption
  253. else 0)
  254. @property
  255. def lora_enabled(self) -> bool:
  256. return bool(self.lora_config)
  257. @property
  258. def num_decoding_tokens_per_seq(self) -> int:
  259. """The number of new tokens."""
  260. return 1
  261. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  262. # Add sequence groups to the waiting queue.
  263. self.waiting.append(seq_group)
  264. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  265. """Aborts a sequence group with the given ID.
  266. Check if the sequence group with the given ID
  267. is present in any of the state queue.
  268. If present, remove the sequence group from the state queue.
  269. Also, if any of the sequences in the sequence group is not finished,
  270. free the sequence with status `FINISHED_ABORTED`.
  271. Otherwise, do nothing.
  272. Args:
  273. request_id: The ID(s) of the sequence group to abort.
  274. """
  275. if isinstance(request_id, str):
  276. request_id = (request_id, )
  277. request_ids = set(request_id)
  278. for state_queue in [self.waiting, self.running, self.swapped]:
  279. aborted_groups: List[SequenceGroup] = []
  280. for seq_group in state_queue:
  281. if not request_ids:
  282. # Using 'break' here may add two extra iterations,
  283. # but is acceptable to reduce complexity .
  284. break
  285. if seq_group.request_id in request_ids:
  286. # Appending aborted group into pending list.
  287. aborted_groups.append(seq_group)
  288. request_ids.remove(seq_group.request_id)
  289. for aborted_group in aborted_groups:
  290. # Remove the sequence group from the state queue.
  291. state_queue.remove(aborted_group)
  292. for seq in aborted_group.get_seqs():
  293. if seq.is_finished():
  294. continue
  295. seq.status = SequenceStatus.FINISHED_ABORTED
  296. self.free_seq(seq)
  297. def has_unfinished_seqs(self) -> bool:
  298. return len(self.waiting) != 0 or len(self.running) != 0 or len(
  299. self.swapped) != 0
  300. def get_num_unfinished_seq_groups(self) -> int:
  301. return len(self.waiting) + len(self.running) + len(self.swapped)
  302. def _schedule_running(
  303. self,
  304. running_queue: deque,
  305. budget: SchedulingBudget,
  306. curr_loras: Optional[Set[int]],
  307. policy: Policy,
  308. enable_chunking: bool = False,
  309. ) -> Tuple[deque, SchedulerRunningOutputs]:
  310. """Schedule sequence groups that are running.
  311. Running queue should include decode and chunked prefill requests.
  312. Args:
  313. running_queue: The queue that contains running requests (i.e.,
  314. decodes). The given arguments are NOT in-place modified.
  315. budget: The scheduling budget. The argument is in-place updated
  316. when any decodes are preempted.
  317. curr_loras: Currently batched lora request ids. The argument is
  318. in-place updated when any decodes are preempted.
  319. policy: The sorting policy to sort running_queue.
  320. enable_chunking: If True, seq group can be chunked and only a
  321. chunked number of tokens are scheduled if
  322. `budget.num_batched_tokens` has not enough capacity to schedule
  323. all tokens.
  324. Returns:
  325. A tuple of remaining running queue (should be always 0) after
  326. scheduling and SchedulerRunningOutputs.
  327. """
  328. # Blocks that need to be swapped or copied before model execution.
  329. blocks_to_swap_out: Dict[int, int] = {}
  330. blocks_to_copy: Dict[int, List[int]] = {}
  331. decode_seq_groups: List[ScheduledSequenceGroup] = []
  332. prefill_seq_groups: List[ScheduledSequenceGroup] = []
  333. preempted: List[SequenceGroup] = []
  334. swapped_out: List[SequenceGroup] = []
  335. # NOTE: Preemption happens only when there is no available slot
  336. # to keep all the sequence groups in the RUNNING state.
  337. # In this case, the policy is responsible for deciding which sequence
  338. # groups to preempt.
  339. now = time.time()
  340. running_queue = policy.sort_by_priority(now, running_queue)
  341. while running_queue:
  342. seq_group = running_queue[0]
  343. num_running_tokens = self._get_num_new_tokens(
  344. seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
  345. if num_running_tokens == 0:
  346. break
  347. running_queue.popleft()
  348. while not self._can_append_slots(seq_group):
  349. budget.subtract_num_batched_tokens(seq_group.request_id,
  350. num_running_tokens)
  351. num_running_seqs = seq_group.get_max_num_running_seqs()
  352. budget.subtract_num_seqs(seq_group.request_id,
  353. num_running_seqs)
  354. if curr_loras is not None and seq_group.lora_int_id > 0:
  355. curr_loras.remove(seq_group.lora_int_id)
  356. if running_queue:
  357. # Preempt the lowest-priority sequence groups.
  358. victim_seq_group = running_queue.pop()
  359. preempted_mode = self._preempt(victim_seq_group,
  360. blocks_to_swap_out)
  361. if preempted_mode == PreemptionMode.RECOMPUTE:
  362. preempted.append(victim_seq_group)
  363. else:
  364. swapped_out.append(victim_seq_group)
  365. else:
  366. # No other sequence groups can be preempted.
  367. # Preempt the current sequence group.
  368. preempted_mode = self._preempt(seq_group,
  369. blocks_to_swap_out)
  370. if preempted_mode == PreemptionMode.RECOMPUTE:
  371. preempted.append(seq_group)
  372. else:
  373. swapped_out.append(seq_group)
  374. break
  375. else:
  376. self._append_slots(seq_group, blocks_to_copy)
  377. is_prefill = seq_group.is_prefill()
  378. if is_prefill:
  379. prefill_seq_groups.append(
  380. ScheduledSequenceGroup(
  381. seq_group=seq_group,
  382. token_chunk_size=num_running_tokens))
  383. else:
  384. decode_seq_groups.append(
  385. ScheduledSequenceGroup(seq_group=seq_group,
  386. token_chunk_size=1))
  387. budget.add_num_batched_tokens(seq_group.request_id,
  388. num_running_tokens)
  389. # OPTIMIZATION: Note that get_max_num_running_seqs is
  390. # expensive. For the default scheduling chase where
  391. # enable_chunking is False, num_seqs are updated before running
  392. # this method, so we don't have to update it again here.
  393. if enable_chunking:
  394. num_running_seqs = seq_group.get_max_num_running_seqs()
  395. budget.add_num_seqs(seq_group.request_id, num_running_seqs)
  396. if curr_loras is not None and seq_group.lora_int_id > 0:
  397. curr_loras.add(seq_group.lora_int_id)
  398. return running_queue, SchedulerRunningOutputs(
  399. decode_seq_groups=decode_seq_groups,
  400. prefill_seq_groups=prefill_seq_groups,
  401. preempted=preempted,
  402. swapped_out=swapped_out,
  403. blocks_to_swap_out=blocks_to_swap_out,
  404. blocks_to_copy=blocks_to_copy,
  405. num_lookahead_slots=self._get_num_lookahead_slots(
  406. is_prefill=False))
  407. def _schedule_swapped(
  408. self,
  409. swapped_queue: deque,
  410. budget: SchedulingBudget,
  411. curr_loras: Optional[Set[int]],
  412. policy: Policy,
  413. enable_chunking: bool = False,
  414. ) -> Tuple[deque, SchedulerSwappedInOutputs]:
  415. """Schedule sequence groups that are swapped out.
  416. It schedules swapped requests as long as it fits `budget` and
  417. curr_loras <= max_lora from the scheduling config. The input arguments
  418. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  419. Args:
  420. swapped_queue: The queue that contains swapped out requests.
  421. The given arguments are NOT in-place modified.
  422. budget: The scheduling budget. The argument is in-place updated
  423. when any requests are swapped in.
  424. curr_loras: Currently batched lora request ids. The argument is
  425. in-place updated when any requests are swapped in.
  426. policy: The sorting policy to sort swapped_queue.
  427. enable_chunking: If True, seq group can be chunked and only a
  428. chunked number of tokens are scheduled if
  429. `budget.num_batched_tokens` has not enough capacity to schedule
  430. all tokens.
  431. Returns:
  432. A tuple of remaining swapped_queue after scheduling and
  433. SchedulerSwappedInOutputs.
  434. """
  435. # Blocks that need to be swapped or copied before model execution.
  436. blocks_to_swap_in: Dict[int, int] = {}
  437. blocks_to_copy: Dict[int, List[int]] = {}
  438. decode_seq_groups: List[ScheduledSequenceGroup] = []
  439. prefill_seq_groups: List[ScheduledSequenceGroup] = []
  440. now = time.time()
  441. swapped_queue = policy.sort_by_priority(now, swapped_queue)
  442. leftover_swapped: Deque[SequenceGroup] = deque()
  443. while swapped_queue:
  444. seq_group = swapped_queue[0]
  445. # If the sequence group cannot be swapped in, stop.
  446. if not self.block_manager.can_swap_in(seq_group):
  447. break
  448. lora_int_id = 0
  449. if self.lora_enabled:
  450. lora_int_id = seq_group.lora_int_id
  451. assert curr_loras is not None
  452. assert self.lora_config is not None
  453. if (lora_int_id > 0 and (lora_int_id not in curr_loras)
  454. and len(curr_loras) >= self.lora_config.max_loras):
  455. # We don't have a space for another LoRA, so
  456. # we ignore this request for now.
  457. leftover_swapped.appendleft(seq_group)
  458. swapped_queue.popleft()
  459. continue
  460. # The total number of sequences in the RUNNING state should not
  461. # exceed the maximum number of sequences.
  462. num_new_seqs = seq_group.get_max_num_running_seqs()
  463. num_new_tokens = self._get_num_new_tokens(seq_group,
  464. SequenceStatus.SWAPPED,
  465. enable_chunking, budget)
  466. if (num_new_tokens == 0
  467. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  468. num_new_seqs=num_new_seqs)):
  469. break
  470. if lora_int_id > 0 and curr_loras is not None:
  471. curr_loras.add(lora_int_id)
  472. swapped_queue.popleft()
  473. self._swap_in(seq_group, blocks_to_swap_in)
  474. self._append_slots(seq_group, blocks_to_copy)
  475. is_prefill = seq_group.is_prefill()
  476. if is_prefill:
  477. prefill_seq_groups.append(
  478. ScheduledSequenceGroup(seq_group,
  479. token_chunk_size=num_new_tokens))
  480. else:
  481. decode_seq_groups.append(
  482. ScheduledSequenceGroup(seq_group, token_chunk_size=1))
  483. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  484. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  485. swapped_queue.extendleft(leftover_swapped)
  486. return swapped_queue, SchedulerSwappedInOutputs(
  487. decode_seq_groups=decode_seq_groups,
  488. prefill_seq_groups=prefill_seq_groups,
  489. blocks_to_swap_in=blocks_to_swap_in,
  490. blocks_to_copy=blocks_to_copy,
  491. num_lookahead_slots=self._get_num_lookahead_slots(
  492. is_prefill=False))
  493. def _schedule_prefills(
  494. self,
  495. waiting_queue: deque,
  496. budget: SchedulingBudget,
  497. curr_loras: Optional[Set[int]],
  498. enable_chunking: bool = False,
  499. ) -> Tuple[deque, SchedulerPrefillOutputs]:
  500. """Schedule sequence groups that are in prefill stage.
  501. Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
  502. as a new prefill (that starts from beginning -> most recently generated
  503. tokens).
  504. It schedules waiting requests as long as it fits `budget` and
  505. curr_loras <= max_lora from the scheduling config. The input arguments
  506. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  507. Args:
  508. waiting_queue: The queue that contains prefill requests.
  509. The given arguments are NOT in-place modified.
  510. budget: The scheduling budget. The argument is in-place updated
  511. when any requests are scheduled.
  512. curr_loras: Currently batched lora request ids. The argument is
  513. in-place updated when any requests are scheduled.
  514. enable_chunking: If True, seq group can be chunked and only a
  515. chunked number of tokens are scheduled if
  516. `budget.num_batched_tokens` has not enough capacity to schedule
  517. all tokens.
  518. Returns:
  519. A tuple of remaining waiting_queue after scheduling and
  520. SchedulerSwappedInOutputs.
  521. """
  522. ignored_seq_groups: List[SequenceGroup] = []
  523. seq_groups: List[SequenceGroup] = []
  524. # We don't sort waiting queue because we assume it is sorted.
  525. # Copy the queue so that the input queue is not modified.
  526. waiting_queue = deque([s for s in waiting_queue])
  527. leftover_waiting_sequences: Deque[SequenceGroup] = deque()
  528. while self._passed_delay(time.time()) and waiting_queue:
  529. seq_group = waiting_queue[0]
  530. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  531. assert len(waiting_seqs) == 1, (
  532. "Waiting sequence group should have only one prompt "
  533. "sequence.")
  534. num_new_tokens = self._get_num_new_tokens(seq_group,
  535. SequenceStatus.WAITING,
  536. enable_chunking, budget)
  537. if not enable_chunking:
  538. num_prompt_tokens = waiting_seqs[0].get_len()
  539. assert num_new_tokens == num_prompt_tokens
  540. if num_new_tokens > self.prompt_limit:
  541. logger.warning(
  542. "Input prompt (%d tokens) is too long"
  543. " and exceeds limit of %d", num_new_tokens,
  544. self.prompt_limit)
  545. for seq in waiting_seqs:
  546. seq.status = SequenceStatus.FINISHED_IGNORED
  547. ignored_seq_groups.append(seq_group)
  548. waiting_queue.popleft()
  549. continue
  550. # If the sequence group cannot be allocated, stop.
  551. can_allocate = self.block_manager.can_allocate(seq_group)
  552. if can_allocate == AllocStatus.LATER:
  553. break
  554. elif can_allocate == AllocStatus.NEVER:
  555. logger.warning(
  556. "Input prompt (%d tokens) is too long"
  557. " and exceeds the capacity of block_manager",
  558. num_new_tokens)
  559. for seq in waiting_seqs:
  560. seq.status = SequenceStatus.FINISHED_IGNORED
  561. ignored_seq_groups.append(seq_group)
  562. waiting_queue.popleft()
  563. continue
  564. lora_int_id = 0
  565. if self.lora_enabled:
  566. lora_int_id = seq_group.lora_int_id
  567. assert curr_loras is not None
  568. assert self.lora_config is not None
  569. if (self.lora_enabled and lora_int_id > 0
  570. and lora_int_id not in curr_loras
  571. and len(curr_loras) >= self.lora_config.max_loras):
  572. # We don't have a space for another LoRA, so
  573. # we ignore this request for now.
  574. leftover_waiting_sequences.appendleft(seq_group)
  575. waiting_queue.popleft()
  576. continue
  577. num_new_seqs = seq_group.get_max_num_running_seqs()
  578. if (num_new_tokens == 0
  579. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  580. num_new_seqs=num_new_seqs)):
  581. break
  582. # Can schedule this request.
  583. if curr_loras is not None and lora_int_id > 0:
  584. curr_loras.add(lora_int_id)
  585. waiting_queue.popleft()
  586. self._allocate_and_set_running(seq_group)
  587. seq_groups.append(
  588. ScheduledSequenceGroup(seq_group=seq_group,
  589. token_chunk_size=num_new_tokens))
  590. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  591. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  592. # Queue requests that couldn't be scheduled.
  593. waiting_queue.extendleft(leftover_waiting_sequences)
  594. if len(seq_groups) > 0:
  595. self.prev_prompt = True
  596. return waiting_queue, SchedulerPrefillOutputs(
  597. seq_groups=seq_groups,
  598. ignored_seq_groups=ignored_seq_groups,
  599. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
  def _schedule_default(self) -> SchedulerOutputs:
      """Schedule queued requests without chunked prefill.

      The policy is designed to optimize throughput and never mixes new
      prefills with decodes in the same batch:

      1. If nothing is swapped out, batch as many new prefill requests from
         ``self.waiting`` as the budget allows. Swapped-out requests take
         priority, so new prefills are skipped while any exist.
      2. Only if no prefill was scheduled, schedule decodes for
         ``self.running``. Under memory pressure, running groups can be
         preempted (returned to the front of the waiting queue) or swapped
         out.
      3. Only if step 2 preempted/swapped out nothing, try to swap groups
         back in from ``self.swapped``.

      Afterwards ``self.waiting``, ``self.running`` and ``self.swapped`` are
      rebuilt from the remaining and newly scheduled groups.

      Returns:
          SchedulerOutputs with the scheduled groups (prefills first, then
          running decodes, then swapped-in decodes), the number of batched
          tokens, and the swap-in / swap-out / copy block mappings.
      """
      # Token/sequence budget for this scheduling step.
      budget = SchedulingBudget(
          token_budget=self.scheduler_config.max_num_batched_tokens,
          max_num_seqs=self.scheduler_config.max_num_seqs,
      )
      # Make sure we include num running seqs before scheduling prefill,
      # so that we don't schedule beyond max_num_seqs for prefill.
      for seq_group in self.running:
          budget.add_num_seqs(seq_group.request_id,
                              seq_group.get_max_num_running_seqs())
      # LoRA ids already in use by running groups (None when LoRA is off).
      curr_loras = set(
          seq_group.lora_int_id
          for seq_group in self.running) if self.lora_enabled else None

      # Defaults: every queue is left untouched and nothing is scheduled,
      # unless the corresponding phase below runs.
      remaining_waiting, prefills = (self.waiting,
                                     SchedulerPrefillOutputs.create_empty())
      remaining_running, running_scheduled = (
          self.running, SchedulerRunningOutputs.create_empty())
      remaining_swapped, swapped_in = (
          self.swapped, SchedulerSwappedInOutputs.create_empty())

      # If any requests are swapped out, skip new prefills so that swapped
      # requests are prioritized.
      if not self.swapped:
          remaining_waiting, prefills = self._schedule_prefills(
              self.waiting, budget, curr_loras, enable_chunking=False)

      fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
      # Don't schedule decodes if prefills are scheduled.
      # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
      # only contains decode requests, not chunked prefills.
      if len(prefills.seq_groups) == 0:
          remaining_running, running_scheduled = self._schedule_running(
              self.running,
              budget,
              curr_loras,
              fcfs_policy,
              enable_chunking=False)

          # If any sequence group was preempted or swapped out, do not swap
          # in any sequence group, because there is no room for new running
          # requests.
          if len(running_scheduled.preempted) + len(
                  running_scheduled.swapped_out) == 0:
              remaining_swapped, swapped_in = self._schedule_swapped(
                  self.swapped, budget, curr_loras, fcfs_policy)

      assert (budget.num_batched_tokens <=
              self.scheduler_config.max_num_batched_tokens)
      assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

      # Update waiting requests. Preempted groups go to the front of the
      # queue; note that `deque.extendleft` inserts them in reverse order.
      self.waiting = remaining_waiting
      self.waiting.extendleft(running_scheduled.preempted)
      # Update new running requests.
      self.running = remaining_running
      self.running.extend([s.seq_group for s in prefills.seq_groups])
      self.running.extend(
          [s.seq_group for s in running_scheduled.decode_seq_groups])
      self.running.extend(
          [s.seq_group for s in swapped_in.decode_seq_groups])
      # Update swapped requests.
      self.swapped = remaining_swapped
      self.swapped.extend(running_scheduled.swapped_out)

      # There should be no prefill from the running or swapped queues,
      # because this policy doesn't allow chunked prefills.
      assert len(running_scheduled.prefill_seq_groups) == 0
      assert len(swapped_in.prefill_seq_groups) == 0

      return SchedulerOutputs(
          scheduled_seq_groups=(prefills.seq_groups +
                                running_scheduled.decode_seq_groups +
                                swapped_in.decode_seq_groups),
          num_prefill_groups=len(prefills.seq_groups),
          num_batched_tokens=budget.num_batched_tokens,
          blocks_to_swap_in=swapped_in.blocks_to_swap_in,
          blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
          blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                     swapped_in.blocks_to_copy),
          ignored_seq_groups=prefills.ignored_seq_groups,
          num_lookahead_slots=running_scheduled.num_lookahead_slots,
      )
  def _schedule_chunked_prefill(self) -> SchedulerOutputs:
      """Schedule queued requests with chunked prefill enabled.

      Chunked prefill allows prefill requests to be split into chunks and
      batched together with decode requests. The policy is:

      1. Schedule as many running decode requests as possible.
      2. Schedule running chunked-prefill requests that are not finished.
         Steps 1 and 2 both happen in one ``_schedule_running`` call.
      3. If nothing was preempted or swapped out, schedule swapped requests
         back in.
      4. Schedule new prefill requests from the waiting queue with whatever
         budget remains.

      Putting prefills and decodes in the same batch keeps GPU utilization
      high. It also improves inter-token latency, because decode requests
      don't have to wait behind prefill requests.

      Returns:
          SchedulerOutputs whose scheduled groups list all prefills (new,
          running, swapped-in) before all decodes (running, swapped-in).
      """
      budget = SchedulingBudget(
          token_budget=self.scheduler_config.max_num_batched_tokens,
          max_num_seqs=self.scheduler_config.max_num_seqs,
      )
      curr_loras: Set[int] = set()

      # Defaults: every queue is left untouched and nothing is scheduled,
      # unless the corresponding phase below runs.
      remaining_waiting, prefills = (self.waiting,
                                     SchedulerPrefillOutputs.create_empty())
      remaining_running, running_scheduled = (
          self.running, SchedulerRunningOutputs.create_empty())
      remaining_swapped, swapped_in = (
          self.swapped, SchedulerSwappedInOutputs.create_empty())

      # Decoding should always be scheduled first, in FCFS order.
      fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
      remaining_running, running_scheduled = self._schedule_running(
          self.running,
          budget,
          curr_loras,
          fcfs_policy,
          enable_chunking=True)

      # Schedule swapped out requests.
      # If preemption happens, it means we don't have space for swap-in.
      if len(running_scheduled.preempted) + len(
              running_scheduled.swapped_out) == 0:
          remaining_swapped, swapped_in = self._schedule_swapped(
              self.swapped, budget, curr_loras, fcfs_policy)

      # Schedule new prefills with the remaining budget.
      remaining_waiting, prefills = self._schedule_prefills(
          self.waiting, budget, curr_loras, enable_chunking=True)

      assert (budget.num_batched_tokens <=
              self.scheduler_config.max_num_batched_tokens)
      assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

      # Update waiting requests. Preempted groups go to the front of the
      # queue; note that `deque.extendleft` inserts them in reverse order.
      self.waiting = remaining_waiting
      self.waiting.extendleft(running_scheduled.preempted)
      # Update new running requests.
      self.running = remaining_running
      self.running.extend([s.seq_group for s in prefills.seq_groups])
      self.running.extend(
          [s.seq_group for s in running_scheduled.decode_seq_groups])
      self.running.extend(
          [s.seq_group for s in running_scheduled.prefill_seq_groups])
      self.running.extend(
          [s.seq_group for s in swapped_in.decode_seq_groups])
      self.running.extend(
          [s.seq_group for s in swapped_in.prefill_seq_groups])
      # Update swapped requests.
      self.swapped = remaining_swapped
      self.swapped.extend(running_scheduled.swapped_out)
      return SchedulerOutputs(
          scheduled_seq_groups=(prefills.seq_groups +
                                running_scheduled.prefill_seq_groups +
                                swapped_in.prefill_seq_groups +
                                running_scheduled.decode_seq_groups +
                                swapped_in.decode_seq_groups),
          num_prefill_groups=(len(prefills.seq_groups) +
                              len(swapped_in.prefill_seq_groups) +
                              len(running_scheduled.prefill_seq_groups)),
          num_batched_tokens=budget.num_batched_tokens,
          blocks_to_swap_in=swapped_in.blocks_to_swap_in,
          blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
          blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                     swapped_in.blocks_to_copy),
          ignored_seq_groups=prefills.ignored_seq_groups,
          num_lookahead_slots=running_scheduled.num_lookahead_slots,
      )
  757. def _schedule(self) -> SchedulerOutputs:
  758. """Schedule queued requests."""
  759. if self.scheduler_config.chunked_prefill_enabled:
  760. return self._schedule_chunked_prefill()
  761. else:
  762. return self._schedule_default()
  763. def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
  764. """Determine whether or not we have enough space in the KV cache to
  765. continue generation of the sequence group.
  766. """# It is True only for testing case to trigger artificial preemption.
  767. if (self.enable_artificial_preemption
  768. and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
  769. and self.artificial_preempt_cnt > 0):
  770. self.artificial_preempt_cnt -= 1
  771. return False
  772. # Appending slots only occurs in decoding.
  773. is_prefill = False
  774. return self.block_manager.can_append_slots(
  775. seq_group=seq_group,
  776. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  777. )
  778. def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
  779. # Swapping in is considered decode.
  780. is_prefill = False
  781. return self.block_manager.can_swap_in(
  782. seq_group=seq_group,
  783. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  784. )
  def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
      """Run one scheduling step and build the model inputs for it.

      Calls ``self._schedule()``, which mutates ``self.running``,
      ``self.swapped`` and ``self.waiting``. Each scheduled sequence group is
      then turned into a ``SequenceGroupMetadata``, which carries the data
      and block tables of its RUNNING sequences. Finally, the blocks of every
      scheduled group are marked as computed.

      Returns:
          A tuple ``(seq_group_metadata_list, scheduler_outputs)``. The
          metadata list follows the order of
          ``scheduler_outputs.scheduled_seq_groups``.
      """
      # Schedule sequence groups.
      # This function call changes the internal states of the scheduler
      # such as self.running, self.swapped, and self.waiting.
      scheduler_outputs = self._schedule()
      now = time.time()

      # Create input data structures.
      seq_group_metadata_list: List[SequenceGroupMetadata] = []
      for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
          seq_group = scheduled_seq_group.seq_group
          token_chunk_size = scheduled_seq_group.token_chunk_size
          seq_group.maybe_set_first_scheduled_time(now)

          # seq_id -> SequenceData
          seq_data: Dict[int, SequenceData] = {}
          # seq_id -> physical block numbers
          block_tables: Dict[int, List[int]] = {}

          for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
              seq_id = seq.seq_id
              seq_data[seq_id] = seq.data
              block_tables[seq_id] = self.block_manager.get_block_table(seq)
              self.block_manager.access_all_blocks_in_seq(seq, now)

          common_computed_block_nums = (
              self.block_manager.get_common_computed_block_ids(
                  seq_group.get_seqs(status=SequenceStatus.RUNNING)))

          # Evaluated once and reused below. This assumes the
          # scheduled_seq_groups are ordered with prefills before decodes.
          is_prompt = seq_group.is_prefill()

          do_sample = True
          if is_prompt:
              seqs = seq_group.get_seqs()
              # Prefill has only 1 sequence.
              assert len(seqs) == 1
              # If this chunk does not reach the end of the sequence, the
              # prefill is chunked and continues in a later step, so there
              # is nothing to sample yet.
              # NOTE: We use get_len instead of get_prompt_len because when
              # a sequence is preempted, prefill includes previous generated
              # output tokens.
              if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
                      seqs[0].data.get_len()):
                  do_sample = False

          seq_group_metadata = SequenceGroupMetadata(
              request_id=seq_group.request_id,
              is_prompt=is_prompt,
              seq_data=seq_data,
              sampling_params=seq_group.sampling_params,
              block_tables=block_tables,
              do_sample=do_sample,
              token_chunk_size=token_chunk_size,
              lora_request=seq_group.lora_request,
              computed_block_nums=common_computed_block_nums,
              state=seq_group.state,
              # `multi_modal_data` will only be present for the 1st comm
              # between engine and worker.
              # the subsequent comms can still use delta, but
              # `multi_modal_data` will be None.
              multi_modal_data=seq_group.multi_modal_data
              if scheduler_outputs.num_prefill_groups > 0 else None,
          )
          seq_group_metadata_list.append(seq_group_metadata)

      # Now that the batch has been created, we can assume all blocks in the
      # batch will have been computed before the next scheduling invocation.
      # This is because the engine assumes that a failure in model execution
      # will crash the Aphrodite instance / will not retry.
      for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
          self.block_manager.mark_blocks_as_computed(
              scheduled_seq_group.seq_group)

      return seq_group_metadata_list, scheduler_outputs
  def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
      """Fork ``parent_seq`` into ``child_seq`` via the block manager."""
      manager = self.block_manager
      manager.fork(parent_seq, child_seq)
  def free_seq(self, seq: Sequence) -> None:
      """Release the blocks held by ``seq`` through the block manager."""
      manager = self.block_manager
      manager.free(seq)
  def free_finished_seq_groups(self) -> None:
      """Drop finished groups from ``self.running``, keeping the order of
      the remaining groups."""
      still_active = [group for group in self.running
                      if not group.is_finished()]
      self.running = deque(still_active)
  def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
      """Allocate blocks for ``seq_group`` and move its WAITING sequences
      to RUNNING."""
      self.block_manager.allocate(seq_group)
      waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
      for pending in waiting_seqs:
          pending.status = SequenceStatus.RUNNING
  def _append_slots(
      self,
      seq_group: SequenceGroup,
      blocks_to_copy: Dict[int, List[int]],
  ) -> None:
      """Append new slots to the running sequences in the given group.

      For every RUNNING sequence, the block manager is asked to append
      slots using the decode lookahead count. It returns a mapping from
      source block to destination blocks that must be copied, and that
      mapping is merged into ``blocks_to_copy``.

      Args:
          seq_group (SequenceGroup): The sequence group containing the
              sequences to append slots to.
          blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
              block indices to lists of destination block indices. Updated
              in place: new destinations are appended to any existing entry
              for the same source block.
      """
      num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
      for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
          cows = self.block_manager.append_slots(seq, num_lookahead_slots)
          for src, dests in cows.items():
              blocks_to_copy.setdefault(src, []).extend(dests)
  def _preempt(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_out: Dict[int, int],
      preemption_mode: Optional[PreemptionMode] = None,
  ) -> PreemptionMode:
      """Preempt ``seq_group`` by recomputation or by swapping.

      Args:
          seq_group: The group to preempt.
          blocks_to_swap_out: Block mapping updated in place when the group
              is swapped out.
          preemption_mode: Mode to use. If None, RECOMPUTE is chosen for
              single-sequence groups and SWAP otherwise.

      Returns:
          The preemption mode that was applied.

      Raises:
          AssertionError: If ``preemption_mode`` is neither RECOMPUTE nor
              SWAP.
      """
      if preemption_mode is None:
          # Recomputation is the default because it incurs lower overhead
          # than swapping. Groups with several sequences (e.g. beam search)
          # cannot be recomputed yet, so they are swapped instead.
          # FIXME: This makes our scheduling policy a bit bizarre.
          # As swapped sequences are prioritized over waiting sequences,
          # sequence groups with multiple sequences are implicitly
          # prioritized over sequence groups with a single sequence.
          # TODO: Support recomputation for sequence groups with multiple
          # sequences. This may require a more sophisticated CUDA kernel.
          single_sequence = seq_group.get_max_num_running_seqs() == 1
          preemption_mode = (PreemptionMode.RECOMPUTE
                             if single_sequence else PreemptionMode.SWAP)

      if preemption_mode == PreemptionMode.RECOMPUTE:
          self._preempt_by_recompute(seq_group)
      elif preemption_mode == PreemptionMode.SWAP:
          self._preempt_by_swap(seq_group, blocks_to_swap_out)
      else:
          raise AssertionError("Invalid preemption mode.")
      return preemption_mode
  def _preempt_by_recompute(
      self,
      seq_group: SequenceGroup,
  ) -> None:
      """Preempt a single-sequence group so it can be recomputed later.

      The running sequence is set back to WAITING, its blocks are released
      via ``self.free_seq``, and ``reset_state_for_recompute`` is called on
      it.
      """
      running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
      # Recompute-based preemption is only used for one-sequence groups.
      assert len(running_seqs) == 1
      for victim in running_seqs:
          victim.status = SequenceStatus.WAITING
          self.free_seq(victim)
          victim.reset_state_for_recompute()
  925. def _preempt_by_swap(
  926. self,
  927. seq_group: SequenceGroup,
  928. blocks_to_swap_out: Dict[int, int],
  929. ) -> None:
  930. self._swap_out(seq_group, blocks_to_swap_out)
  def _swap_in(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_in: Dict[int, int],
  ) -> None:
      """Swap ``seq_group`` back in and mark its SWAPPED sequences RUNNING.

      The block mapping returned by ``block_manager.swap_in`` is merged into
      ``blocks_to_swap_in`` in place.
      """
      blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
      swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED)
      for restored in swapped_seqs:
          restored.status = SequenceStatus.RUNNING
  940. def _swap_out(
  941. self,
  942. seq_group: SequenceGroup,
  943. blocks_to_swap_out: Dict[int, int],
  944. ) -> None:
  945. if not self.block_manager.can_swap_out(seq_group):
  946. # FIXME: Abort the sequence group instead of aborting the
  947. # entire engine.
  948. raise RuntimeError(
  949. "Aborted due to the lack of CPU swap space. Please increase "
  950. "the swap space to avoid this error.")
  951. mapping = self.block_manager.swap_out(seq_group)
  952. blocks_to_swap_out.update(mapping)
  953. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  954. seq.status = SequenceStatus.SWAPPED
  955. def _passed_delay(self, now: float) -> bool:
  956. if self.prev_prompt:
  957. self.last_prompt_latency = now - self.prev_time
  958. self.prev_time, self.prev_prompt = now, False
  959. # Delay scheduling prompts to let waiting queue fill up
  960. if self.scheduler_config.delay_factor > 0 and self.waiting:
  961. earliest_arrival_time = min(
  962. [e.metrics.arrival_time for e in self.waiting])
  963. passed_delay = (
  964. (now - earliest_arrival_time) >
  965. (self.scheduler_config.delay_factor * self.last_prompt_latency)
  966. or not self.running)
  967. else:
  968. passed_delay = True
  969. return passed_delay
  970. def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
  971. """The number of slots to allocate per sequence per step, beyond known
  972. token ids. Speculative decoding uses these slots to store KV activations
  973. of tokens which may or may not be accepted.
  974. Speculative decoding does not yet support prefill, so we do not perform
  975. lookahead allocation for prefill.
  976. """
  977. if is_prefill:
  978. return 0
  979. return self.scheduler_config.num_lookahead_slots
  980. def _get_num_new_tokens(self, seq_group: SequenceGroup,
  981. status: SequenceStatus, enable_chunking: bool,
  982. budget: SchedulingBudget) -> int:
  983. """Get the next new tokens to compute for a given sequence group
  984. that's in a given `status`.
  985. The API could chunk the number of tokens to compute based on `budget`
  986. if `enable_chunking` is True. If a sequence group has multiple
  987. sequences (e.g., running beam search), it means it is in decoding
  988. phase, so chunking doesn't happen.
  989. Returns 0 if the new token cannot be computed due to token budget.
  990. """
  991. num_new_tokens = 0
  992. seqs = seq_group.get_seqs(status=status)
  993. for seq in seqs:
  994. num_new_tokens += seq.get_num_new_tokens()
  995. assert num_new_tokens > 0
  996. # Chunk if a running request cannot fit in.
  997. # If number of seq > 1, it means it is doing beam search in a
  998. # decode phase. Do not chunk in that case.
  999. if enable_chunking and len(seqs) == 1:
  1000. num_new_tokens = min(num_new_tokens,
  1001. budget.remaining_token_budget())
  1002. return num_new_tokens