scheduler.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106
  1. import enum
  2. import time
  3. from collections import deque
  4. from dataclasses import dataclass, field
  5. from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
  6. from loguru import logger
  7. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  8. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  9. from aphrodite.processing.policy import Policy, PolicyFactory
  10. from aphrodite.lora.request import LoRARequest
  11. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  12. SequenceGroupMetadata, SequenceStatus)
  13. from aphrodite.common.utils import merge_dicts
  14. class PreemptionMode(enum.Enum):
  15. """Preemption modes.
  16. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  17. and swap them back in when the sequences are resumed.
  18. 2. Recomputation: Discard the blocks of the preempted sequences and
  19. recompute them when the sequences are resumed, treating the sequences as
  20. new prompts.
  21. """
  22. SWAP = enum.auto()
  23. RECOMPUTE = enum.auto()
  24. @dataclass
  25. class SchedulingBudget:
  26. """The available slots for scheduling.
  27. TODO: Right now, the budget is request_id-aware meaning it can ignore
  28. budget update from the same request_id. It is because in normal scheduling
  29. path, we update RUNNING num_seqs ahead of time, meaning it could be
  30. updated more than once when scheduling RUNNING requests. Since this won't
  31. happen if we only have chunked prefill scheduling, we can remove this
  32. feature from the API when chunked prefill is enabled by default.
  33. """
  34. token_budget: int
  35. max_num_seqs: int
  36. _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
  37. _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
  38. _num_batched_tokens: int = 0
  39. _num_curr_seqs: int = 0
  40. def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
  41. assert num_new_tokens != 0
  42. assert num_new_seqs != 0
  43. return (self.num_batched_tokens + num_new_tokens <= self.token_budget
  44. and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
  45. def remaining_token_budget(self):
  46. return self.token_budget - self.num_batched_tokens
  47. def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
  48. if req_id in self._requeset_ids_num_batched_tokens:
  49. return
  50. self._requeset_ids_num_batched_tokens.add(req_id)
  51. self._num_batched_tokens += num_batched_tokens
  52. def subtract_num_batched_tokens(self, req_id: str,
  53. num_batched_tokens: int):
  54. if req_id in self._requeset_ids_num_batched_tokens:
  55. self._requeset_ids_num_batched_tokens.remove(req_id)
  56. self._num_batched_tokens -= num_batched_tokens
  57. def add_num_seqs(self, req_id: str, num_curr_seqs: int):
  58. if req_id in self._requeset_ids_num_curr_seqs:
  59. return
  60. self._requeset_ids_num_curr_seqs.add(req_id)
  61. self._num_curr_seqs += num_curr_seqs
  62. def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
  63. if req_id in self._requeset_ids_num_curr_seqs:
  64. self._requeset_ids_num_curr_seqs.remove(req_id)
  65. self._num_curr_seqs -= num_curr_seqs
  66. @property
  67. def num_batched_tokens(self):
  68. return self._num_batched_tokens
  69. @property
  70. def num_curr_seqs(self):
  71. return self._num_curr_seqs
  72. @dataclass
  73. class ScheduledSequenceGroup:
  74. # A sequence group that's scheduled.
  75. seq_group: SequenceGroup
  76. # The total chunk size (number of tokens) to process for next iteration.
  77. # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
  78. # chunked, it can be smaller than that.
  79. token_chunk_size: int
  80. @dataclass
  81. class SchedulerOutputs:
  82. """The scheduling decision made from a scheduler."""
  83. # Scheduled sequence groups.
  84. scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
  85. # Number of prefill groups scheduled.
  86. num_prefill_groups: int
  87. # Total number of batched tokens.
  88. num_batched_tokens: int
  89. # Blocks to swap in. Dict of CPU -> GPU block number.
  90. blocks_to_swap_in: Dict[int, int]
  91. # Blocks to swap out. Dict of GPU -> CPU block number.
  92. blocks_to_swap_out: Dict[int, int]
  93. # Blocks to copy. Source to a list of dest blocks.
  94. blocks_to_copy: Dict[int, List[int]]
  95. # Sequence groups that are going to be ignored.
  96. ignored_seq_groups: List[SequenceGroup]
  97. # The number of slots for lookahead decoding.
  98. num_lookahead_slots: int
  99. def __post_init__(self):
  100. # Swap in and swap out should never happen at the same time.
  101. assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
  102. self.num_loras: int = len(self.lora_requests)
  103. if self.num_loras > 0:
  104. self._sort_by_lora_ids()
  105. def is_empty(self) -> bool:
  106. # NOTE: We do not consider the ignored sequence groups.
  107. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  108. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  109. def _sort_by_lora_ids(self) -> bool:
  110. self.scheduled_seq_groups = sorted(
  111. self.scheduled_seq_groups,
  112. key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
  113. @property
  114. def lora_requests(self) -> Set[LoRARequest]:
  115. return {
  116. g.seq_group.lora_request
  117. for g in self.scheduled_seq_groups
  118. if g.seq_group.lora_request is not None
  119. }
  120. @dataclass
  121. class SchedulerRunningOutputs:
  122. """The requests that are scheduled from a running queue.
  123. Could contain prefill (prefill that's chunked) or decodes. If there's not
  124. enough memory, it can be preempted (for recompute) or swapped out.
  125. """
  126. # Selected sequences that are running and in a decoding phase.
  127. decode_seq_groups: List[SequenceGroup]
  128. # Selected sequences that are running and in a prefill phase.
  129. # I.e., it means the prefill has been chunked.
  130. prefill_seq_groups: List[SequenceGroup]
  131. # The preempted sequences.
  132. preempted: List[SequenceGroup]
  133. # Sequences that are swapped out.
  134. swapped_out: List[SequenceGroup]
  135. # The blocks to swap out.
  136. blocks_to_swap_out: Dict[int, int]
  137. # The blocks to copy.
  138. blocks_to_copy: Dict[int, List[int]]
  139. # The number of slots for lookahead decoding.
  140. num_lookahead_slots: int
  141. @classmethod
  142. def create_empty(cls) -> "SchedulerRunningOutputs":
  143. return SchedulerRunningOutputs(
  144. decode_seq_groups=[],
  145. prefill_seq_groups=[],
  146. preempted=[],
  147. swapped_out=[],
  148. blocks_to_swap_out={},
  149. blocks_to_copy={},
  150. num_lookahead_slots=0,
  151. )
  152. @dataclass
  153. class SchedulerSwappedInOutputs:
  154. """The requests that are scheduled from a swap queue.
  155. Could contain prefill (prefill that's chunked) or decodes.
  156. """
  157. # Selected sequences that are going to be swapped in and is in a
  158. # decoding phase.
  159. decode_seq_groups: List[SequenceGroup]
  160. # Selected sequences that are going to be swapped in and in a prefill
  161. # phase. I.e., it means the prefill has been chunked.
  162. prefill_seq_groups: List[SequenceGroup]
  163. # The blocks to swap in.
  164. blocks_to_swap_in: Dict[int, int]
  165. # The blocks to copy.
  166. blocks_to_copy: Dict[int, List[int]]
  167. # The number of slots for lookahead decoding.
  168. num_lookahead_slots: int
  169. @classmethod
  170. def create_empty(cls) -> "SchedulerSwappedInOutputs":
  171. return SchedulerSwappedInOutputs(
  172. decode_seq_groups=[],
  173. prefill_seq_groups=[],
  174. blocks_to_swap_in={},
  175. blocks_to_copy={},
  176. num_lookahead_slots=0,
  177. )
  178. @dataclass
  179. class SchedulerPrefillOutputs:
  180. """The requests that are scheduled from a waiting queue.
  181. Could contain a fresh prefill requests or preempted requests that need
  182. to be recomputed from scratch.
  183. """
  184. # Selected sequences for prefill.
  185. seq_groups: List[SequenceGroup]
  186. # Ignored sequence groups.
  187. ignored_seq_groups: List[SequenceGroup]
  188. num_lookahead_slots: int
  189. @classmethod
  190. def create_empty(cls) -> "SchedulerPrefillOutputs":
  191. return SchedulerPrefillOutputs(
  192. seq_groups=[],
  193. ignored_seq_groups=[],
  194. num_lookahead_slots=0,
  195. )
  196. class Scheduler:
  197. def __init__(
  198. self,
  199. scheduler_config: SchedulerConfig,
  200. cache_config: CacheConfig,
  201. lora_config: Optional[LoRAConfig],
  202. ) -> None:
  203. self.scheduler_config = scheduler_config
  204. self.cache_config = cache_config
  205. # Note for LoRA scheduling: the current policy is extremely
  206. # simple and NOT fair. It can lead to starvation of some
  207. # LoRAs. This should be improved in the future.
  208. self.lora_config = lora_config
  209. if self.scheduler_config.chunked_prefill_enabled:
  210. self.prompt_limit = self.scheduler_config.max_model_len
  211. else:
  212. self.prompt_limit = min(
  213. self.scheduler_config.max_model_len,
  214. self.scheduler_config.max_num_batched_tokens)
  215. BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
  216. version="v2" if self.scheduler_config.
  217. use_v2_block_manager else "v1")
  218. # Create the block space manager.
  219. self.block_manager = BlockSpaceManagerImpl(
  220. block_size=self.cache_config.block_size,
  221. num_gpu_blocks=self.cache_config.num_gpu_blocks,
  222. num_cpu_blocks=self.cache_config.num_cpu_blocks,
  223. sliding_window=self.cache_config.sliding_window,
  224. enable_caching=self.cache_config.context_shift)
  225. # Sequence groups in the WAITING state.
  226. # Contain new prefill or preempted requests.
  227. self.waiting: Deque[SequenceGroup] = deque()
  228. # Sequence groups in the RUNNING state.
  229. # Contain decode requests.
  230. self.running: Deque[SequenceGroup] = deque()
  231. # Sequence groups in the SWAPPED state.
  232. # Contain decode requests that are swapped out.
  233. self.swapped: Deque[SequenceGroup] = deque()
  234. # Time at previous scheduling step
  235. self.prev_time = 0.0
  236. # Did we schedule a prompt at previous step?
  237. self.prev_prompt = False
  238. # Latency of the last prompt step
  239. self.last_prompt_latency = 0.0
  240. @property
  241. def lora_enabled(self) -> bool:
  242. return bool(self.lora_config)
  243. @property
  244. def num_decoding_tokens_per_seq(self) -> int:
  245. """The number of new tokens."""
  246. return 1
  247. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  248. # Add sequence groups to the waiting queue.
  249. logger.debug(f"add_seq_group {seq_group.request_id}")
  250. self.waiting.append(seq_group)
  251. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  252. """Aborts a sequence group with the given ID.
  253. Check if the sequence group with the given ID
  254. is present in any of the state queue.
  255. If present, remove the sequence group from the state queue.
  256. Also, if any of the sequences in the sequence group is not finished,
  257. free the sequence with status `FINISHED_ABORTED`.
  258. Otherwise, do nothing.
  259. Args:
  260. request_id: The ID(s) of the sequence group to abort.
  261. """
  262. if isinstance(request_id, str):
  263. request_id = (request_id, )
  264. request_ids = set(request_id)
  265. for state_queue in [self.waiting, self.running, self.swapped]:
  266. aborted_groups: List[SequenceGroup] = []
  267. for seq_group in state_queue:
  268. if not request_ids:
  269. # Using 'break' here may add two extra iterations,
  270. # but is acceptable to reduce complexity .
  271. break
  272. if seq_group.request_id in request_ids:
  273. # Appending aborted group into pending list.
  274. aborted_groups.append(seq_group)
  275. request_ids.remove(seq_group.request_id)
  276. for aborted_group in aborted_groups:
  277. # Remove the sequence group from the state queue.
  278. state_queue.remove(aborted_group)
  279. for seq in aborted_group.get_seqs():
  280. if seq.is_finished():
  281. continue
  282. seq.status = SequenceStatus.FINISHED_ABORTED
  283. self.free_seq(seq)
  284. def has_unfinished_seqs(self) -> bool:
  285. return self.waiting or self.running or self.swapped
  286. def get_num_unfinished_seq_groups(self) -> int:
  287. return len(self.waiting) + len(self.running) + len(self.swapped)
  288. def _schedule_running(
  289. self,
  290. running_queue: deque,
  291. budget: SchedulingBudget,
  292. curr_loras: Optional[Set[int]],
  293. policy: Policy,
  294. enable_chunking: bool = False,
  295. ) -> Tuple[deque, SchedulerRunningOutputs]:
  296. """Schedule sequence groups that are running.
  297. Running queue should include decode and chunked prefill requests.
  298. Args:
  299. running_queue: The queue that contains running requests (i.e.,
  300. decodes). The given arguments are NOT in-place modified.
  301. budget: The scheduling budget. The argument is in-place updated
  302. when any decodes are preempted.
  303. curr_loras: Currently batched lora request ids. The argument is
  304. in-place updated when any decodes are preempted.
  305. policy: The sorting policy to sort running_queue.
  306. enable_chunking: If True, seq group can be chunked and only a
  307. chunked number of tokens are scheduled if
  308. `budget.num_batched_tokens` has not enough capacity to schedule
  309. all tokens.
  310. Returns:
  311. A tuple of remaining running queue (should be always 0) after
  312. scheduling and SchedulerRunningOutputs
  313. """
  314. # Blocks that need to be swapped or copied before model execution.
  315. blocks_to_swap_out: Dict[int, int] = {}
  316. blocks_to_copy: Dict[int, List[int]] = {}
  317. decode_seq_groups: List[ScheduledSequenceGroup] = []
  318. prefill_seq_groups: List[ScheduledSequenceGroup] = []
  319. preempted: List[SequenceGroup] = []
  320. swapped_out: List[SequenceGroup] = []
  321. # NOTE: Preemption happens only when there is no available slot
  322. # to keep all the sequence groups in the RUNNING state.
  323. # In this case, the policy is responsible for deciding which sequence
  324. # groups to preempt.
  325. now = time.time()
  326. running_queue = policy.sort_by_priority(now, running_queue)
  327. while running_queue:
  328. seq_group = running_queue[0]
  329. num_running_tokens = self._get_num_new_tokens(
  330. seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
  331. # We can have up to 1 running prefill at any given time in running
  332. # queue, which means we can guarantee chunk size is at least 1.
  333. assert num_running_tokens != 0
  334. num_running_seqs = seq_group.get_max_num_running_seqs()
  335. running_queue.popleft()
  336. while not self._can_append_slots(seq_group):
  337. budget.subtract_num_batched_tokens(seq_group.request_id,
  338. num_running_tokens)
  339. budget.subtract_num_seqs(seq_group.request_id,
  340. num_running_seqs)
  341. if curr_loras is not None and seq_group.lora_int_id > 0:
  342. curr_loras.pop(seq_group.lora_int_id)
  343. if running_queue:
  344. # Preempt the lowest-priority sequence groups.
  345. victim_seq_group = running_queue.pop()
  346. preempted_mode = self._preempt(victim_seq_group,
  347. blocks_to_swap_out)
  348. if preempted_mode == PreemptionMode.RECOMPUTE:
  349. preempted.append(victim_seq_group)
  350. else:
  351. swapped_out.append(victim_seq_group)
  352. else:
  353. # No other sequence groups can be preempted.
  354. # Preempt the current sequence group.
  355. preempted_mode = self._preempt(seq_group,
  356. blocks_to_swap_out)
  357. if preempted_mode == PreemptionMode.RECOMPUTE:
  358. preempted.append(seq_group)
  359. else:
  360. swapped_out.append(seq_group)
  361. break
  362. else:
  363. logger.debug(f"append slot for {seq_group}")
  364. self._append_slots(seq_group, blocks_to_copy)
  365. is_prefill = seq_group.is_prefill()
  366. if is_prefill:
  367. prefill_seq_groups.append(
  368. ScheduledSequenceGroup(
  369. seq_group=seq_group,
  370. token_chunk_size=num_running_tokens))
  371. else:
  372. decode_seq_groups.append(
  373. ScheduledSequenceGroup(seq_group=seq_group,
  374. token_chunk_size=1))
  375. budget.add_num_batched_tokens(seq_group.request_id,
  376. num_running_tokens)
  377. budget.add_num_seqs(seq_group.request_id, num_running_seqs)
  378. if curr_loras is not None and seq_group.lora_int_id > 0:
  379. curr_loras.add(seq_group.lora_int_id)
  380. # Make sure all queues are updated.
  381. assert len(running_queue) == 0
  382. return running_queue, SchedulerRunningOutputs(
  383. decode_seq_groups=decode_seq_groups,
  384. prefill_seq_groups=prefill_seq_groups,
  385. preempted=preempted,
  386. swapped_out=swapped_out,
  387. blocks_to_swap_out=blocks_to_swap_out,
  388. blocks_to_copy=blocks_to_copy,
  389. num_lookahead_slots=self._get_num_lookahead_slots(
  390. is_prefill=False))
  391. def _schedule_swapped(
  392. self,
  393. swapped_queue: deque,
  394. budget: SchedulingBudget,
  395. curr_loras: Optional[Set[int]],
  396. policy: Policy,
  397. enable_chunking: bool = False,
  398. ) -> Tuple[deque, SchedulerSwappedInOutputs]:
  399. """Schedule sequence groups that are swapped out.
  400. It schedules swapped requests as long as it fits `budget` and
  401. curr_loras <= max_lora from the scheduling config. The input arguments
  402. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  403. Args:
  404. swapped_queue: The queue that contains swapped out requests.
  405. The given arguments are NOT in-place modified.
  406. budget: The scheduling budget. The argument is in-place updated
  407. when any requests are swapped in.
  408. curr_loras: Currently batched lora request ids. The argument is
  409. in-place updated when any requests are swapped in.
  410. policy: The sorting policy to sort swapped_queue.
  411. enable_chunking: If True, seq group can be chunked and only a
  412. chunked number of tokens are scheduled if
  413. `budget.num_batched_tokens` has not enough capacity to schedule
  414. all tokens.
  415. Returns:
  416. A tuple of remaining swapped_queue after scheduling and
  417. SchedulerSwappedInOutputs.
  418. """
  419. # Blocks that need to be swapped or copied before model execution.
  420. blocks_to_swap_in: Dict[int, int] = {}
  421. blocks_to_copy: Dict[int, List[int]] = {}
  422. decode_seq_groups: List[ScheduledSequenceGroup] = []
  423. prefill_seq_groups: List[ScheduledSequenceGroup] = []
  424. now = time.time()
  425. swapped_queue = policy.sort_by_priority(now, swapped_queue)
  426. leftover_swapped = deque()
  427. while swapped_queue:
  428. seq_group = swapped_queue[0]
  429. # If the sequence group cannot be swapped in, stop.
  430. if not self.block_manager.can_swap_in(seq_group):
  431. break
  432. lora_int_id = 0
  433. if self.lora_enabled:
  434. lora_int_id = seq_group.lora_int_id
  435. if (lora_int_id > 0 and lora_int_id not in curr_loras
  436. and len(curr_loras) >= self.lora_config.max_loras):
  437. # We don't have a space for another LoRA, so
  438. # we ignore this request for now.
  439. leftover_swapped.appendleft(seq_group)
  440. swapped_queue.popleft()
  441. continue
  442. # The total number of sequences in the RUNNING state should not
  443. # exceed the maximum number of sequences.
  444. num_new_seqs = seq_group.get_max_num_running_seqs()
  445. num_new_tokens = self._get_num_new_tokens(seq_group,
  446. SequenceStatus.SWAPPED,
  447. enable_chunking, budget)
  448. if (num_new_tokens == 0
  449. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  450. num_new_seqs=num_new_seqs)):
  451. break
  452. if lora_int_id > 0 and curr_loras is not None:
  453. curr_loras.add(lora_int_id)
  454. swapped_queue.popleft()
  455. self._swap_in(seq_group, blocks_to_swap_in)
  456. self._append_slots(seq_group, blocks_to_copy)
  457. is_prefill = seq_group.is_prefill()
  458. if is_prefill:
  459. prefill_seq_groups.append(
  460. ScheduledSequenceGroup(seq_group,
  461. token_chunk_size=num_new_tokens))
  462. else:
  463. assert num_new_tokens == 1
  464. decode_seq_groups.append(
  465. ScheduledSequenceGroup(seq_group, token_chunk_size=1))
  466. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  467. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  468. swapped_queue.extendleft(leftover_swapped)
  469. return swapped_queue, SchedulerSwappedInOutputs(
  470. decode_seq_groups=decode_seq_groups,
  471. prefill_seq_groups=prefill_seq_groups,
  472. blocks_to_swap_in=blocks_to_swap_in,
  473. blocks_to_copy=blocks_to_copy,
  474. num_lookahead_slots=self._get_num_lookahead_slots(
  475. is_prefill=False))
  476. def _schedule_prefills(
  477. self,
  478. waiting_queue: deque,
  479. budget: SchedulingBudget,
  480. curr_loras: Optional[Set[int]],
  481. enable_chunking: bool = False,
  482. ) -> Tuple[deque, SchedulerPrefillOutputs]:
  483. """Schedule sequence groups that are in prefill stage.
  484. Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
  485. as a new prefill (that starts from beginning -> most recently generated
  486. tokens).
  487. It schedules waiting requests as long as it fits `budget` and
  488. curr_loras <= max_lora from the scheduling config. The input arguments
  489. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  490. Args:
  491. waiting_queue: The queue that contains prefill requests.
  492. The given arguments are NOT in-place modified.
  493. budget: The scheduling budget. The argument is in-place updated
  494. when any requests are scheduled.
  495. curr_loras: Currently batched lora request ids. The argument is
  496. in-place updated when any requests are scheduled.
  497. enable_chunking: If True, seq group can be chunked and only a
  498. chunked number of tokens are scheduled if
  499. `budget.num_batched_tokens` has not enough capacity to schedule
  500. all tokens.
  501. Returns:
  502. A tuple of remaining waiting_queue after scheduling and
  503. SchedulerSwappedInOutputs.
  504. """
  505. ignored_seq_groups: List[SequenceGroup] = []
  506. seq_groups: List[SequenceGroup] = []
  507. # We don't sort waiting queue because we assume it is sorted.
  508. # Copy the queue so that the input queue is not modified.
  509. waiting_queue = deque([s for s in waiting_queue])
  510. leftover_waiting_sequences = deque()
  511. while self._passed_delay(time.time()) and waiting_queue:
  512. seq_group = waiting_queue[0]
  513. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  514. assert len(waiting_seqs) == 1, (
  515. "Waiting sequence group should have only one prompt "
  516. "sequence.")
  517. num_new_tokens = self._get_num_new_tokens(seq_group,
  518. SequenceStatus.WAITING,
  519. enable_chunking, budget)
  520. if not enable_chunking:
  521. num_prompt_tokens = waiting_seqs[0].get_len()
  522. assert num_new_tokens == num_prompt_tokens
  523. if num_new_tokens > self.prompt_limit:
  524. logger.warning(
  525. f"Input prompt ({num_new_tokens} tokens) is too long"
  526. f" and exceeds limit of {self.prompt_limit}")
  527. for seq in waiting_seqs:
  528. seq.status = SequenceStatus.FINISHED_IGNORED
  529. ignored_seq_groups.append(seq_group)
  530. waiting_queue.popleft()
  531. continue
  532. # If the sequence group cannot be allocated, stop.
  533. can_allocate = self.block_manager.can_allocate(seq_group)
  534. if can_allocate == AllocStatus.LATER:
  535. break
  536. elif can_allocate == AllocStatus.NEVER:
  537. logger.warning(
  538. f"Input prompt ({num_new_tokens} tokens) is too long"
  539. f" and exceeds the capacity of block_manager")
  540. for seq in waiting_seqs:
  541. seq.status = SequenceStatus.FINISHED_IGNORED
  542. ignored_seq_groups.append(seq_group)
  543. waiting_queue.popleft()
  544. continue
  545. lora_int_id = 0
  546. if self.lora_enabled:
  547. lora_int_id = seq_group.lora_int_id
  548. if (self.lora_enabled and lora_int_id > 0
  549. and lora_int_id not in curr_loras
  550. and len(curr_loras) >= self.lora_config.max_loras):
  551. # We don't have a space for another LoRA, so
  552. # we ignore this request for now.
  553. leftover_waiting_sequences.appendleft(seq_group)
  554. waiting_queue.popleft()
  555. continue
  556. num_new_seqs = seq_group.get_max_num_running_seqs()
  557. if (num_new_tokens == 0
  558. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  559. num_new_seqs=num_new_seqs)):
  560. break
  561. # Can schedule this request.
  562. if curr_loras is not None and lora_int_id > 0:
  563. curr_loras.add(lora_int_id)
  564. waiting_queue.popleft()
  565. self._allocate_and_set_running(seq_group, num_new_tokens)
  566. seq_groups.append(
  567. ScheduledSequenceGroup(seq_group=seq_group,
  568. token_chunk_size=num_new_tokens))
  569. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  570. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  571. # Queue requests that couldn't be scheduled.
  572. waiting_queue.extendleft(leftover_waiting_sequences)
  573. if len(seq_groups) > 0:
  574. self.prev_prompt = True
  575. return waiting_queue, SchedulerPrefillOutputs(
  576. seq_groups=seq_groups,
  577. ignored_seq_groups=ignored_seq_groups,
  578. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
  579. def _schedule_default(self) -> SchedulerOutputs:
  580. """Schedule queued requests.
  581. The current policy is designed to opimimize the throughput. First,
  582. it batches as many prefill requests as possible. And it schedules
  583. decodes. If there's a pressure on GPU memory, decode requests can
  584. be swapped or preempted.
  585. """
  586. # Include running requests to the budget.
  587. budget = SchedulingBudget(
  588. token_budget=self.scheduler_config.max_num_batched_tokens,
  589. max_num_seqs=self.scheduler_config.max_num_seqs,
  590. )
  591. # Make sure we include num running seqs before scheduling prefill,
  592. # so that we don't schedule beyond max_num_seqs for prefill.
  593. for seq_group in self.running:
  594. budget.add_num_seqs(seq_group.request_id,
  595. seq_group.get_max_num_running_seqs())
  596. curr_loras = set(
  597. seq_group.lora_int_id
  598. for seq_group in self.running) if self.lora_enabled else None
  599. remaining_waiting, prefills = (self.waiting,
  600. SchedulerPrefillOutputs.create_empty())
  601. remaining_running, running_scheduled = (
  602. self.running, SchedulerRunningOutputs.create_empty())
  603. remaining_swapped, swapped_in = (
  604. self.swapped, SchedulerSwappedInOutputs.create_empty())
  605. # If any requests are swapped, prioritized swapped requests.
  606. if not self.swapped:
  607. remaining_waiting, prefills = self._schedule_prefills(
  608. self.waiting, budget, curr_loras, enable_chunking=False)
  609. fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
  610. # Don't schedule decodes if prefills are scheduled.
  611. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
  612. # only contains decode requests, not chunked prefills.
  613. if len(prefills.seq_groups) == 0:
  614. remaining_running, running_scheduled = self._schedule_running(
  615. self.running,
  616. budget,
  617. curr_loras,
  618. fcfs_policy,
  619. enable_chunking=False)
  620. # If any sequence group is preempted, do not swap in any sequence
  621. # group. because it means there's no slot for new running requests.
  622. if len(running_scheduled.preempted) + len(
  623. running_scheduled.swapped_out) == 0:
  624. remaining_swapped, swapped_in = self._schedule_swapped(
  625. self.swapped, budget, curr_loras, fcfs_policy)
  626. assert (budget.num_batched_tokens <=
  627. self.scheduler_config.max_num_batched_tokens)
  628. assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
  629. # Update waiting requests.
  630. self.waiting = remaining_waiting
  631. self.waiting.extendleft(running_scheduled.preempted)
  632. # Update new running requests.
  633. self.running = remaining_running
  634. self.running.extend([s.seq_group for s in prefills.seq_groups])
  635. self.running.extend(
  636. [s.seq_group for s in running_scheduled.decode_seq_groups])
  637. self.running.extend(
  638. [s.seq_group for s in swapped_in.decode_seq_groups])
  639. # Update swapped requests.
  640. self.swapped = remaining_swapped
  641. self.swapped.extend(running_scheduled.swapped_out)
  642. # There should be no prefill from running queue because this policy
  643. # doesn't allow chunked prefills.
  644. assert len(running_scheduled.prefill_seq_groups) == 0
  645. assert len(swapped_in.prefill_seq_groups) == 0
  646. return SchedulerOutputs(
  647. scheduled_seq_groups=(prefills.seq_groups +
  648. running_scheduled.decode_seq_groups +
  649. swapped_in.decode_seq_groups),
  650. num_prefill_groups=len(prefills.seq_groups),
  651. num_batched_tokens=budget.num_batched_tokens,
  652. blocks_to_swap_in=swapped_in.blocks_to_swap_in,
  653. blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
  654. blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
  655. swapped_in.blocks_to_copy),
  656. ignored_seq_groups=prefills.ignored_seq_groups,
  657. num_lookahead_slots=(prefills.num_lookahead_slots +
  658. running_scheduled.num_lookahead_slots +
  659. swapped_in.num_lookahead_slots),
  660. )
  def _schedule_chunked_prefill(self):
      """Schedule queued requests with chunked prefill enabled.

      Chunked prefill splits prefill requests into chunks so they can be
      batched together with decode requests under one token budget. Each
      step this policy:

      1. Schedules as many running requests as possible, in FCFS order.
         With chunking enabled this covers both decodes and prefills that
         were chunked in an earlier step and are not finished yet.
      2. Schedules swapped-out requests, unless step 1 preempted or swapped
         out any request.
      3. Schedules new prefill requests from the waiting queue with
         whatever budget remains.

      Putting prefills and decodes in the same batch sustains high GPU
      utilization, while scheduling decodes first improves inter-token
      latency because decode requests are not blocked by prefill requests.

      Returns:
          The ``SchedulerOutputs`` for this step. ``self.waiting``,
          ``self.running`` and ``self.swapped`` are reassigned to the
          post-scheduling queues.
      """
      budget = SchedulingBudget(
          token_budget=self.scheduler_config.max_num_batched_tokens,
          max_num_seqs=self.scheduler_config.max_num_seqs,
      )
      curr_loras = set()
      # Swap-in below is conditional, so start from "nothing swapped in".
      # The running and prefill results are always assigned, so they need
      # no placeholder values.
      remaining_swapped, swapped_in = (
          self.swapped, SchedulerSwappedInOutputs.create_empty())

      # Running requests (decodes and unfinished chunked prefills) are
      # always scheduled first, in FCFS order.
      fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
      remaining_running, running_scheduled = self._schedule_running(
          self.running,
          budget,
          curr_loras,
          fcfs_policy,
          enable_chunking=True)

      # If any preemption happened, there is no space for swap-in.
      if len(running_scheduled.preempted) + len(
              running_scheduled.swapped_out) == 0:
          remaining_swapped, swapped_in = self._schedule_swapped(
              self.swapped, budget, curr_loras, fcfs_policy)

      # New prefills get whatever budget is left.
      remaining_waiting, prefills = self._schedule_prefills(
          self.waiting, budget, curr_loras, enable_chunking=True)

      assert (budget.num_batched_tokens <=
              self.scheduler_config.max_num_batched_tokens)
      assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

      # Update waiting requests; preempted groups go to the left end.
      self.waiting = remaining_waiting
      self.waiting.extendleft(running_scheduled.preempted)
      # Update new running requests.
      self.running = remaining_running
      self.running.extend([s.seq_group for s in prefills.seq_groups])
      self.running.extend(
          [s.seq_group for s in running_scheduled.decode_seq_groups])
      self.running.extend(
          [s.seq_group for s in running_scheduled.prefill_seq_groups])
      self.running.extend(
          [s.seq_group for s in swapped_in.decode_seq_groups])
      self.running.extend(
          [s.seq_group for s in swapped_in.prefill_seq_groups])
      # Update swapped requests.
      self.swapped = remaining_swapped
      self.swapped.extend(running_scheduled.swapped_out)
      return SchedulerOutputs(
          scheduled_seq_groups=(prefills.seq_groups +
                                running_scheduled.prefill_seq_groups +
                                swapped_in.prefill_seq_groups +
                                running_scheduled.decode_seq_groups +
                                swapped_in.decode_seq_groups),
          num_prefill_groups=(len(prefills.seq_groups) +
                              len(swapped_in.prefill_seq_groups) +
                              len(running_scheduled.prefill_seq_groups)),
          num_batched_tokens=budget.num_batched_tokens,
          blocks_to_swap_in=swapped_in.blocks_to_swap_in,
          blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
          blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                     swapped_in.blocks_to_copy),
          ignored_seq_groups=prefills.ignored_seq_groups,
          num_lookahead_slots=(prefills.num_lookahead_slots +
                               running_scheduled.num_lookahead_slots +
                               swapped_in.num_lookahead_slots),
      )
  740. def _schedule(self) -> SchedulerOutputs:
  741. """Schedule queued requests."""
  742. if self.scheduler_config.chunked_prefill_enabled:
  743. return self._schedule_chunked_prefill()
  744. else:
  745. return self._schedule_default()
  746. def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
  747. """Determine whether or not we have enough space in the KV cache to
  748. continue generation of the sequence group.
  749. """
  750. # Appending slots only occurs in decoding.
  751. is_prefill = False
  752. return self.block_manager.can_append_slots(
  753. seq_group=seq_group,
  754. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  755. )
  756. def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
  757. # Swapping in is considered decode.
  758. is_prefill = False
  759. return self.block_manager.can_swap_in(
  760. seq_group=seq_group,
  761. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  762. )
  def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
      """Run one scheduling step and build the worker inputs for it.

      ``self._schedule()`` reassigns ``self.running``, ``self.swapped`` and
      ``self.waiting``. Each scheduled sequence group is then turned into a
      ``SequenceGroupMetadata`` that describes its RUNNING sequences.

      Returns:
          ``(seq_group_metadata_list, scheduler_outputs)``. The metadata list
          is in the same order as ``scheduler_outputs.scheduled_seq_groups``.
      """
      # Schedule sequence groups. This call changes the internal state of
      # the scheduler (self.running, self.swapped and self.waiting).
      scheduler_outputs = self._schedule()
      now = time.time()

      # Create input data structures.
      seq_group_metadata_list: List[SequenceGroupMetadata] = []
      for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
          seq_group = scheduled_seq_group.seq_group
          token_chunk_size = scheduled_seq_group.token_chunk_size
          seq_group.maybe_set_first_scheduled_time(now)

          # Fetch the running sequences once and reuse the list below.
          running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
          # seq_id -> SequenceData
          seq_data: Dict[int, SequenceData] = {}
          # seq_id -> physical block numbers
          block_tables: Dict[int, List[int]] = {}
          # seq_id -> persistent data
          persistent_data: Dict[int, dict] = {}
          for seq in running_seqs:
              seq_id = seq.seq_id
              seq_data[seq_id] = seq.data
              block_tables[seq_id] = self.block_manager.get_block_table(seq)
              persistent_data[seq_id] = seq.persistent_data
              self.block_manager.access_all_blocks_in_seq(seq, now)

          common_computed_block_nums = (
              self.block_manager.get_common_computed_block_ids(running_seqs))

          # It assumes the scheduled_seq_groups is ordered by
          # prefill < decoding.
          is_prompt = seq_group.is_prefill()
          seq_group_metadata = SequenceGroupMetadata(
              request_id=seq_group.request_id,
              is_prompt=is_prompt,
              seq_data=seq_data,
              sampling_params=seq_group.sampling_params,
              block_tables=block_tables,
              token_chunk_size=token_chunk_size,
              lora_request=seq_group.lora_request,
              persistent_data=persistent_data,
              computed_block_nums=common_computed_block_nums,
              state=seq_group.state,
              # `multi_modal_data` will only be present for the 1st comm
              # between engine and worker. Subsequent comms can still use
              # delta, but `multi_modal_data` will be None.
              # NOTE(review): the condition is batch-wide
              # (num_prefill_groups > 0), not per-group `is_prompt`, so
              # decode groups that share a batch with a prefill also carry
              # it. Confirm this is intended.
              multi_modal_data=seq_group.multi_modal_data
              if scheduler_outputs.num_prefill_groups > 0 else None,
          )
          seq_group_metadata_list.append(seq_group_metadata)

      # Now that the batch has been created, we can assume all blocks in the
      # batch will have been computed before the next scheduling invocation.
      # This is because the engine assumes that a failure in model execution
      # will crash the Aphrodite instance / will not retry.
      # NOTE(review): this is kept as a separate pass after all metadata is
      # built, presumably so the computed-block queries above are not
      # affected by marks from other groups in the same batch (e.g. shared
      # prefix blocks). Confirm before folding it into the loop above.
      for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
          self.block_manager.mark_blocks_as_computed(
              scheduled_seq_group.seq_group)
      return seq_group_metadata_list, scheduler_outputs
  def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
      """Ask the block manager to fork *parent_seq* into *child_seq*."""
      manager = self.block_manager
      manager.fork(parent_seq, child_seq)
  def free_seq(self, seq: Sequence) -> None:
      """Release *seq* from the block manager's block table."""
      manager = self.block_manager
      manager.free(seq)
  def free_finished_seq_groups(self) -> None:
      """Drop finished sequence groups from the running queue.

      ``self.running`` is replaced by a new deque that keeps the unfinished
      groups in their original order.
      """
      unfinished = [group for group in self.running
                    if not group.is_finished()]
      self.running = deque(unfinished)
  def _allocate_and_set_running(self, seq_group: SequenceGroup,
                                num_new_tokens: int) -> None:
      """Allocate blocks for *seq_group* and mark its waiting seqs RUNNING.

      ``num_new_tokens`` is accepted for interface compatibility but is not
      used here.
      """
      self.block_manager.allocate(seq_group)
      waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
      for waiting_seq in waiting_seqs:
          waiting_seq.status = SequenceStatus.RUNNING
  def _append_slots(
      self,
      seq_group: SequenceGroup,
      blocks_to_copy: Dict[int, List[int]],
  ) -> None:
      """Append decode slots to every running sequence in *seq_group*.

      Args:
          seq_group (SequenceGroup): The group whose RUNNING sequences get
              new slots.
          blocks_to_copy (Dict[int, List[int]]): Mapping from source block
              index to a list of destination block indices. Updated in place
              with the source -> destinations pairs returned by
              ``block_manager.append_slots`` for each sequence.
      """
      lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
      for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
          copy_pairs = self.block_manager.append_slots(running_seq,
                                                       lookahead_slots)
          for src_block, dst_blocks in copy_pairs.items():
              blocks_to_copy.setdefault(src_block, []).extend(dst_blocks)
  def _preempt(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_out: Dict[int, int],
      preemption_mode: Optional[PreemptionMode] = None,
  ) -> PreemptionMode:
      """Preempt *seq_group* and return the preemption mode that was used.

      When ``preemption_mode`` is None, RECOMPUTE is chosen if
      ``seq_group.get_max_num_running_seqs()`` is 1, and SWAP otherwise.

      Raises:
          AssertionError: if the mode is neither RECOMPUTE nor SWAP.
      """
      # Recomputation is the default because it incurs lower overhead than
      # swapping. It is not currently supported for groups with multiple
      # sequences (e.g. beam search), so those are swapped instead.
      # FIXME: This makes our scheduling policy a bit bizarre.
      # As swapped sequences are prioritized over waiting sequences,
      # sequence groups with multiple sequences are implicitly prioritized
      # over sequence groups with a single sequence.
      # TODO: Support recomputation for sequence groups with multiple
      # sequences. This may require a more sophisticated CUDA kernel.
      if preemption_mode is None:
          single_seq = seq_group.get_max_num_running_seqs() == 1
          preemption_mode = (PreemptionMode.RECOMPUTE
                             if single_seq else PreemptionMode.SWAP)

      if preemption_mode == PreemptionMode.SWAP:
          self._preempt_by_swap(seq_group, blocks_to_swap_out)
      elif preemption_mode == PreemptionMode.RECOMPUTE:
          self._preempt_by_recompute(seq_group)
      else:
          raise AssertionError("Invalid preemption mode.")
      return preemption_mode
  def _preempt_by_recompute(
      self,
      seq_group: SequenceGroup,
  ) -> None:
      """Preempt *seq_group* by recomputation.

      Each RUNNING sequence (exactly one is expected) is set back to
      WAITING, freed via ``free_seq`` and then reset with
      ``reset_state_for_recompute()``.
      """
      running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
      # Recompute preemption only supports single-sequence groups.
      assert len(running_seqs) == 1
      for victim in running_seqs:
          victim.status = SequenceStatus.WAITING
          self.free_seq(victim)
          victim.reset_state_for_recompute()
  894. def _preempt_by_swap(
  895. self,
  896. seq_group: SequenceGroup,
  897. blocks_to_swap_out: Dict[int, int],
  898. ) -> None:
  899. self._swap_out(seq_group, blocks_to_swap_out)
  def _swap_in(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_in: Dict[int, int],
  ) -> None:
      """Swap *seq_group* back in and mark its SWAPPED sequences RUNNING.

      The block mapping returned by ``block_manager.swap_in`` is merged into
      ``blocks_to_swap_in`` in place.
      """
      blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
      for swapped_seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
          swapped_seq.status = SequenceStatus.RUNNING
  def _swap_out(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_out: Dict[int, int],
  ) -> None:
      """Swap *seq_group* out and mark its RUNNING sequences SWAPPED.

      The block mapping returned by ``block_manager.swap_out`` is merged
      into ``blocks_to_swap_out`` in place.

      Raises:
          RuntimeError: if ``block_manager.can_swap_out(seq_group)`` is
              false, i.e. there is not enough CPU swap space.
      """
      manager = self.block_manager
      if not manager.can_swap_out(seq_group):
          # FIXME: Abort the sequence group instead of aborting the
          # entire engine.
          raise RuntimeError(
              "Aborted due to the lack of CPU swap space. Please increase "
              "the swap space to avoid this error.")
      blocks_to_swap_out.update(manager.swap_out(seq_group))
      for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
          running_seq.status = SequenceStatus.SWAPPED
  924. def _passed_delay(self, now: float) -> bool:
  925. if self.prev_prompt:
  926. self.last_prompt_latency = now - self.prev_time
  927. self.prev_time, self.prev_prompt = now, False
  928. # Delay scheduling prompts to let waiting queue fill up
  929. if self.scheduler_config.delay_factor > 0 and self.waiting:
  930. earliest_arrival_time = min(
  931. [e.metrics.arrival_time for e in self.waiting])
  932. passed_delay = (
  933. (now - earliest_arrival_time) >
  934. (self.scheduler_config.delay_factor * self.last_prompt_latency)
  935. or not self.running)
  936. else:
  937. passed_delay = True
  938. return passed_delay
  939. def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
  940. """The number of slots to allocate per sequence per step, beyond known
  941. token ids. Speculative decoding uses these slots to store KV activations
  942. of tokens which may or may not be accepted.
  943. Speculative decoding does not yet support prefill, so we do not perform
  944. lookahead allocation for prefill.
  945. """
  946. if is_prefill:
  947. return 0
  948. return self.scheduler_config.num_lookahead_slots
  949. def _get_num_new_tokens(self, seq_group: SequenceGroup,
  950. status: SequenceStatus, enable_chunking: bool,
  951. budget: SchedulingBudget) -> Tuple[int, bool]:
  952. """Get the next new tokens to compute for a given sequence group
  953. that's in a given `status`.
  954. The API could chunk the number of tokens to compute based on `budget`
  955. if `enable_chunking` is True. If a sequence group has multiple
  956. sequences (e.g., running beam search), it means it is in decoding
  957. phase, so chunking doesn't happen.
  958. """
  959. num_new_tokens = 0
  960. seqs = seq_group.get_seqs(status=status)
  961. for seq in seqs:
  962. num_new_tokens += seq.get_num_new_tokens()
  963. # Chunk if a running request cannot fit in.
  964. # If number of seq > 1, it means it is doing beam search in a
  965. # decode phase. Do not chunk in that case.
  966. if enable_chunking and len(seqs) == 1:
  967. num_new_tokens = min(num_new_tokens,
  968. budget.remaining_token_budget())
  969. return num_new_tokens