scheduler.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463
  1. import enum
  2. import os
  3. import random
  4. import time
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
  8. Tuple, Union)
  9. from loguru import logger
  10. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  11. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  12. SequenceGroupMetadata,
  13. SequenceGroupMetadataDelta,
  14. SequenceStatus)
  15. from aphrodite.common.utils import Device, PyObjectCache
  16. from aphrodite.lora.request import LoRARequest
  17. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  18. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  19. # Test-only. If configured, decode is preempted with
  20. # ARTIFICIAL_PREEMPTION_PROB% probability.
  21. ENABLE_ARTIFICIAL_PREEMPT = bool(
  22. os.getenv("APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
  23. ARTIFICIAL_PREEMPTION_PROB = 0.5
  24. ARTIFICIAL_PREEMPTION_MAX_CNT = 500
  25. class PreemptionMode(enum.Enum):
  26. """Preemption modes.
  27. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  28. and swap them back in when the sequences are resumed.
  29. 2. Recomputation: Discard the blocks of the preempted sequences and
  30. recompute them when the sequences are resumed, treating the sequences as
  31. new prompts.
  32. """
  33. SWAP = enum.auto()
  34. RECOMPUTE = enum.auto()
  35. @dataclass
  36. class SchedulingBudget:
  37. """The available slots for scheduling.
  38. TODO: Right now, the budget is request_id-aware meaning it can ignore
  39. budget update from the same request_id. It is because in normal scheduling
  40. path, we update RUNNING num_seqs ahead of time, meaning it could be
  41. updated more than once when scheduling RUNNING requests. Since this won't
  42. happen if we only have chunked prefill scheduling, we can remove this
  43. feature from the API when chunked prefill is enabled by default.
  44. """
  45. token_budget: int
  46. max_num_seqs: int
  47. _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
  48. _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
  49. _num_batched_tokens: int = 0
  50. _num_curr_seqs: int = 0
  51. def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
  52. assert num_new_tokens != 0
  53. assert num_new_seqs != 0
  54. return (self.num_batched_tokens + num_new_tokens <= self.token_budget
  55. and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
  56. def remaining_token_budget(self):
  57. return self.token_budget - self.num_batched_tokens
  58. def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
  59. if req_id in self._request_ids_num_batched_tokens:
  60. return
  61. self._request_ids_num_batched_tokens.add(req_id)
  62. self._num_batched_tokens += num_batched_tokens
  63. def subtract_num_batched_tokens(self, req_id: str,
  64. num_batched_tokens: int):
  65. if req_id in self._request_ids_num_batched_tokens:
  66. self._request_ids_num_batched_tokens.remove(req_id)
  67. self._num_batched_tokens -= num_batched_tokens
  68. def add_num_seqs(self, req_id: str, num_curr_seqs: int):
  69. if req_id in self._request_ids_num_curr_seqs:
  70. return
  71. self._request_ids_num_curr_seqs.add(req_id)
  72. self._num_curr_seqs += num_curr_seqs
  73. def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
  74. if req_id in self._request_ids_num_curr_seqs:
  75. self._request_ids_num_curr_seqs.remove(req_id)
  76. self._num_curr_seqs -= num_curr_seqs
  77. @property
  78. def num_batched_tokens(self):
  79. return self._num_batched_tokens
  80. @property
  81. def num_curr_seqs(self):
  82. return self._num_curr_seqs
  83. @dataclass
  84. class ScheduledSequenceGroup:
  85. # A sequence group that's scheduled.
  86. seq_group: SequenceGroup
  87. # The total chunk size (number of tokens) to process for next iteration.
  88. # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
  89. # chunked, it can be smaller than that.
  90. token_chunk_size: int
  91. @dataclass
  92. class SchedulerOutputs:
  93. """The scheduling decision made from a scheduler."""
  94. # Scheduled sequence groups.
  95. scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
  96. # Number of prefill groups scheduled.
  97. num_prefill_groups: int
  98. # Total number of batched tokens.
  99. num_batched_tokens: int
  100. # Blocks to swap in. List of CPU -> GPU block number.
  101. blocks_to_swap_in: List[Tuple[int, int]]
  102. # Blocks to swap out. List of GPU -> CPU block number.
  103. blocks_to_swap_out: List[Tuple[int, int]]
  104. # Blocks to copy. Source to dest block.
  105. blocks_to_copy: List[Tuple[int, int]]
  106. # Sequence groups that are going to be ignored.
  107. ignored_seq_groups: List[SequenceGroup]
  108. # The number of slots for lookahead decoding.
  109. num_lookahead_slots: int
  110. # The number of requests in the running queue
  111. running_queue_size: int
  112. preempted: int
  113. def __post_init__(self):
  114. # Swap in and swap out should never happen at the same time.
  115. assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
  116. self.num_loras: int = len(self.lora_requests)
  117. if self.num_loras > 0:
  118. self._sort_by_lora_ids()
  119. self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
  120. def is_empty(self) -> bool:
  121. # NOTE: We do not consider the ignored sequence groups.
  122. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  123. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  124. def _sort_by_lora_ids(self):
  125. self.scheduled_seq_groups = sorted(
  126. self.scheduled_seq_groups,
  127. key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
  128. @property
  129. def lora_requests(self) -> Set[LoRARequest]:
  130. return {
  131. g.seq_group.lora_request
  132. for g in self.scheduled_seq_groups
  133. if g.seq_group.lora_request is not None
  134. }
  135. @property
  136. def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
  137. return {
  138. g.seq_group.prompt_adapter_request
  139. for g in self.scheduled_seq_groups
  140. if g.seq_group.prompt_adapter_request is not None
  141. }
  142. @dataclass
  143. class SchedulerRunningOutputs:
  144. """The requests that are scheduled from a running queue.
  145. Could contain prefill (prefill that's chunked) or decodes. If there's not
  146. enough memory, it can be preempted (for recompute) or swapped out.
  147. """
  148. # Selected sequences that are running and in a decoding phase.
  149. decode_seq_groups: List[ScheduledSequenceGroup]
  150. # Selected sequences that are running and in a prefill phase.
  151. # I.e., it means the prefill has been chunked.
  152. prefill_seq_groups: List[ScheduledSequenceGroup]
  153. # The preempted sequences.
  154. preempted: List[SequenceGroup]
  155. # Sequences that are swapped out.
  156. swapped_out: List[SequenceGroup]
  157. # The blocks to swap out.
  158. blocks_to_swap_out: List[Tuple[int, int]]
  159. # The blocks to copy.
  160. blocks_to_copy: List[Tuple[int, int]]
  161. # The number of slots for lookahead decoding.
  162. num_lookahead_slots: int
  163. # Optimization for fast-access to seq_group lists
  164. decode_seq_groups_list: List[SequenceGroup]
  165. prefill_seq_groups_list: List[SequenceGroup]
  166. @classmethod
  167. def create_empty(cls) -> "SchedulerRunningOutputs":
  168. return SchedulerRunningOutputs(
  169. decode_seq_groups=[],
  170. prefill_seq_groups=[],
  171. preempted=[],
  172. swapped_out=[],
  173. blocks_to_swap_out=[],
  174. blocks_to_copy=[],
  175. num_lookahead_slots=0,
  176. decode_seq_groups_list=[],
  177. prefill_seq_groups_list=[],
  178. )
  179. @dataclass
  180. class SchedulerSwappedInOutputs:
  181. """The requests that are scheduled from a swap queue.
  182. Could contain prefill (prefill that's chunked) or decodes.
  183. """
  184. # Selected sequences that are going to be swapped in and is in a
  185. # decoding phase.
  186. decode_seq_groups: List[SequenceGroup]
  187. # Selected sequences that are going to be swapped in and in a prefill
  188. # phase. I.e., it means the prefill has been chunked.
  189. prefill_seq_groups: List[SequenceGroup]
  190. # The blocks to swap in.
  191. blocks_to_swap_in: List[Tuple[int, int]]
  192. # The blocks to copy.
  193. blocks_to_copy: List[Tuple[int, int]]
  194. # The number of slots for lookahead decoding.
  195. num_lookahead_slots: int
  196. # Infeasible sequence groups.
  197. infeasible_seq_groups: List[SequenceGroup]
  198. @classmethod
  199. def create_empty(cls) -> "SchedulerSwappedInOutputs":
  200. return SchedulerSwappedInOutputs(
  201. decode_seq_groups=[],
  202. prefill_seq_groups=[],
  203. blocks_to_swap_in=[],
  204. blocks_to_copy=[],
  205. num_lookahead_slots=0,
  206. infeasible_seq_groups=[],
  207. )
  208. @dataclass
  209. class SchedulerPrefillOutputs:
  210. """The requests that are scheduled from a waiting queue.
  211. Could contain a fresh prefill requests or preempted requests that need
  212. to be recomputed from scratch.
  213. """
  214. # Selected sequences for prefill.
  215. seq_groups: List[SequenceGroup]
  216. # Ignored sequence groups.
  217. ignored_seq_groups: List[SequenceGroup]
  218. num_lookahead_slots: int
  219. @classmethod
  220. def create_empty(cls) -> "SchedulerPrefillOutputs":
  221. return SchedulerPrefillOutputs(
  222. seq_groups=[],
  223. ignored_seq_groups=[],
  224. num_lookahead_slots=0,
  225. )
  226. def seq_group_metadata_builder():
  227. return SequenceGroupMetadata(request_id="",
  228. is_prompt=False,
  229. seq_data={},
  230. sampling_params=None,
  231. block_tables={})
  232. def scheduler_running_outputs_builder():
  233. return SchedulerRunningOutputs(decode_seq_groups=[],
  234. prefill_seq_groups=[],
  235. preempted=[],
  236. swapped_out=[],
  237. blocks_to_swap_out=[],
  238. blocks_to_copy=[],
  239. num_lookahead_slots=0,
  240. prefill_seq_groups_list=[],
  241. decode_seq_groups_list=[])
  242. def scheduled_seq_group_builder():
  243. return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
  244. class Scheduler:
  245. def __init__(
  246. self,
  247. scheduler_config: SchedulerConfig,
  248. cache_config: CacheConfig,
  249. lora_config: Optional[LoRAConfig],
  250. pipeline_parallel_size: int = 1,
  251. output_proc_callback: Optional[Callable] = None,
  252. ) -> None:
  253. self.scheduler_config = scheduler_config
  254. self.cache_config = cache_config
  255. # Note for LoRA scheduling: the current policy is extremely
  256. # simple and NOT fair. It can lead to starvation of some
  257. # LoRAs. This should be improved in the future.
  258. self.lora_config = lora_config
  259. version = "v1"
  260. if self.scheduler_config.use_v2_block_manager:
  261. version = "v2"
  262. if (self.scheduler_config.embedding_mode
  263. or self.scheduler_config.is_attention_free):
  264. version = "placeholder"
  265. BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
  266. version)
  267. num_gpu_blocks = cache_config.num_gpu_blocks
  268. if num_gpu_blocks:
  269. num_gpu_blocks //= pipeline_parallel_size
  270. num_cpu_blocks = cache_config.num_cpu_blocks
  271. if num_cpu_blocks:
  272. num_cpu_blocks //= pipeline_parallel_size
  273. # Create the block space manager.
  274. self.block_manager = BlockSpaceManagerImpl(
  275. block_size=self.cache_config.block_size,
  276. num_gpu_blocks=num_gpu_blocks,
  277. num_cpu_blocks=num_cpu_blocks,
  278. sliding_window=self.cache_config.sliding_window,
  279. enable_caching=self.cache_config.enable_prefix_caching)
  280. # Sequence groups in the WAITING state.
  281. # Contain new prefill or preempted requests.
  282. self.waiting: Deque[SequenceGroup] = deque()
  283. # Sequence groups in the RUNNING state.
  284. # Contain decode requests.
  285. self.running: Deque[SequenceGroup] = deque()
  286. # Sequence groups in the SWAPPED state.
  287. # Contain decode requests that are swapped out.
  288. self.swapped: Deque[SequenceGroup] = deque()
  289. # Sequence groups finished requests ids since last step iteration.
  290. # It lets the model know that any state associated with these requests
  291. # can and must be released after the current step.
  292. # This is used to evict the finished requests from the Mamba cache.
  293. self._finished_requests_ids: List[str] = list()
  294. # Time at previous scheduling step
  295. self.prev_time = 0.0
  296. # Did we schedule a prompt at previous step?
  297. self.prev_prompt = False
  298. # Latency of the last prompt step
  299. self.last_prompt_latency = 0.0
  300. # preemption mode, RECOMPUTE or SWAP
  301. self.user_specified_preemption_mode = scheduler_config.preemption_mode
  302. # The following field is test-only. It is used to inject artificial
  303. # preemption.
  304. self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
  305. self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
  306. if self.enable_artificial_preemption
  307. else 0)
  308. self.num_cumulative_preemption: int = 0
  309. # Used to cache python objects
  310. self._seq_group_metadata_cache: List[PyObjectCache] = []
  311. self._scheduler_running_outputs_cache: List[PyObjectCache] = []
  312. self._scheduled_seq_group_cache: List[PyObjectCache] = []
  313. # For async output processing, we need to swap cache buffers between
  314. # iterations. I.e. since the output processing is lagged one step,
  315. # we cannot reuse the cached objects immediately when the schedule()
  316. # is called again, but only when schedule() is called the second time.
  317. self.output_proc_callback = output_proc_callback
  318. self.use_async_output_proc = self.output_proc_callback is not None
  319. self.num_cache_iters = 2 if self.use_async_output_proc else 1
  320. self.cache_id = 0
  321. for i in range(self.num_cache_iters):
  322. self._seq_group_metadata_cache.append(
  323. PyObjectCache(seq_group_metadata_builder))
  324. self._scheduler_running_outputs_cache.append(
  325. PyObjectCache(scheduler_running_outputs_builder))
  326. self._scheduled_seq_group_cache.append(
  327. PyObjectCache(scheduled_seq_group_builder))
  328. # For async postprocessor, the extra decode run cannot be done
  329. # when the request reaches max_model_len. In this case, the request
  330. # will be stopped during schedule() call and added to this stop list
  331. # for processing and deallocation by the free_finished_seq_groups()
  332. self._async_stopped: List[SequenceGroup] = []
  333. @property
  334. def next_cache_id(self):
  335. return (self.cache_id + 1) % self.num_cache_iters
  336. @property
  337. def lora_enabled(self) -> bool:
  338. return bool(self.lora_config)
  339. @property
  340. def num_decoding_tokens_per_seq(self) -> int:
  341. """The number of new tokens."""
  342. return 1
  343. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  344. # Add sequence groups to the waiting queue.
  345. self.waiting.append(seq_group)
  346. def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
  347. # Add sequence groups to the running queue.
  348. # Only for testing purposes.
  349. self.running.append(seq_group)
  350. def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
  351. # Add sequence groups to the swapped queue.
  352. # Only for testing purposes.
  353. self.swapped.append(seq_group)
  354. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  355. """Aborts a sequence group with the given ID.
  356. Check if the sequence group with the given ID
  357. is present in any of the state queue.
  358. If present, remove the sequence group from the state queue.
  359. Also, if any of the sequences in the sequence group is not finished,
  360. free the sequence with status `FINISHED_ABORTED`.
  361. Otherwise, do nothing.
  362. Args:
  363. request_id: The ID(s) of the sequence group to abort.
  364. """
  365. if isinstance(request_id, str):
  366. request_id = (request_id, )
  367. request_ids = set(request_id)
  368. for state_queue in [self.waiting, self.running, self.swapped]:
  369. aborted_groups: List[SequenceGroup] = []
  370. for seq_group in state_queue:
  371. if not request_ids:
  372. # Using 'break' here may add two extra iterations,
  373. # but is acceptable to reduce complexity.
  374. break
  375. if seq_group.request_id in request_ids:
  376. # Appending aborted group into pending list.
  377. aborted_groups.append(seq_group)
  378. request_ids.remove(seq_group.request_id)
  379. for aborted_group in aborted_groups:
  380. # Remove the sequence group from the state queue.
  381. state_queue.remove(aborted_group)
  382. # Remove the aborted request from the Mamba cache.
  383. self._finished_requests_ids.append(aborted_group.request_id)
  384. for seq in aborted_group.get_seqs():
  385. if seq.is_finished():
  386. continue
  387. seq.status = SequenceStatus.FINISHED_ABORTED
  388. self.free_seq(seq)
  389. self._free_seq_group_cross_attn_blocks(aborted_group)
  390. def _free_seq_group_cross_attn_blocks(
  391. self,
  392. seq_group: SequenceGroup,
  393. ) -> None:
  394. """
  395. Free a sequence group from a cross-attention block table.
  396. Has no effect on decoder-only models.
  397. """
  398. if seq_group.is_encoder_decoder():
  399. self.block_manager.free_cross(seq_group)
  400. def has_unfinished_seqs(self) -> bool:
  401. return len(self.waiting) != 0 or len(self.running) != 0 or len(
  402. self.swapped) != 0
  403. def get_prefix_cache_hit_rate(self, device: Device) -> float:
  404. return self.block_manager.get_prefix_cache_hit_rate(device)
  405. def get_num_unfinished_seq_groups(self) -> int:
  406. return len(self.waiting) + len(self.running) + len(self.swapped)
  407. def get_and_reset_finished_requests_ids(self) -> List[str]:
  408. """Flushes the list of request ids of previously finished seq_groups."""
  409. finished_requests_ids = self._finished_requests_ids
  410. self._finished_requests_ids = list()
  411. return finished_requests_ids
    def _schedule_running(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerRunningOutputs:
        """Schedule the sequence groups in the RUNNING queue.

        The running queue holds decode and chunked-prefill requests. Groups
        are taken from the head of ``self.running``. When a group cannot get
        slots for its next tokens, groups are preempted from the tail of the
        queue until it fits. If no other group is left, the group itself is
        preempted.

        Args:
            budget: The scheduling budget. Updated in place: each scheduled
                group is charged, and the current group's charge is refunded
                while preemption is attempted for it.
            curr_loras: Currently batched LoRA int ids, or None. Updated in
                place: the current group's id is removed while preemption is
                attempted and added once the group is scheduled.
            enable_chunking: If True, a seq group can be chunked, so that only
                a chunked number of tokens is scheduled when
                `budget.num_batched_tokens` does not have enough capacity
                for all of them.

        Returns:
            SchedulerRunningOutputs. The object comes from the pooled
            ``_scheduler_running_outputs_cache[self.cache_id]``; its lists are
            cleared and refilled on every call.
        """
        # Reuse a pooled output object instead of allocating a new one; every
        # list in it is cleared before being refilled below.
        ret: SchedulerRunningOutputs = \
            self._scheduler_running_outputs_cache[self.cache_id].get_object()
        ret.blocks_to_swap_out.clear()
        ret.blocks_to_copy.clear()
        ret.decode_seq_groups.clear()
        ret.prefill_seq_groups.clear()
        ret.preempted.clear()
        ret.swapped_out.clear()

        ret.num_lookahead_slots = self._get_num_lookahead_slots(
            is_prefill=False)

        ret.decode_seq_groups_list.clear()
        ret.prefill_seq_groups_list.clear()

        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy

        # Local aliases for the pooled result lists.
        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
        prefill_seq_groups: List[
            ScheduledSequenceGroup] = ret.prefill_seq_groups
        preempted: List[SequenceGroup] = ret.preempted
        swapped_out: List[SequenceGroup] = ret.swapped_out

        # NOTE: Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.

        # Store original running requests for the case of async + preemption.
        # (orig_running is only bound, and only read, when async output
        # processing is enabled.)
        if self.use_async_output_proc:
            orig_running = self.running.copy()

        running_queue = self.running
        assert len(self._async_stopped) == 0
        while running_queue:
            seq_group = running_queue[0]
            num_running_tokens = self._get_num_new_tokens(
                seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

            # No tokens to schedule for the head group: stop and leave it
            # (and everything behind it) in the running queue.
            if num_running_tokens == 0:
                break

            running_queue.popleft()

            # With async postprocessor, an extra decode run is done
            # to process the final tokens. The check below avoids this extra
            # decode run when the model max len is reached, in order to avoid
            # a memory overflow.
            if self.use_async_output_proc and seq_group.seqs[0].get_len(
            ) > self.scheduler_config.max_model_len:
                self._async_stopped.append(seq_group)
                continue

            # With async postprocessor, when preemption kicks in, we need
            # first to drain the async postprocessor, so that all async
            # block_table freeing is applied before the preemption freeing
            # is applied. self.running is temporarily pointed at the
            # pre-scheduling snapshot while the callback runs.
            if self.use_async_output_proc and not self._can_append_slots(
                    seq_group):
                tmp = self.running
                self.running = orig_running
                assert self.output_proc_callback is not None
                self.output_proc_callback()
                self.running = tmp

            # Keep preempting until seq_group has room. The `else` clause of
            # this while-loop runs only if the loop exits without `break`,
            # i.e. when seq_group itself was not preempted.
            while not self._can_append_slots(seq_group):
                # Undo any charge for seq_group; it is re-added below if the
                # group ends up being scheduled.
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
                num_running_seqs = seq_group.get_max_num_running_seqs()
                budget.subtract_num_seqs(seq_group.request_id,
                                         num_running_seqs)

                if (curr_loras is not None and seq_group.lora_int_id > 0
                        and seq_group.lora_int_id in curr_loras):
                    curr_loras.remove(seq_group.lora_int_id)

                if running_queue:
                    # Preempt the lowest-priority sequence groups (the tail
                    # of the running queue).
                    victim_seq_group = running_queue.pop()
                    preempted_mode = self._preempt(victim_seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(victim_seq_group)
                    else:
                        swapped_out.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    preempted_mode = self._preempt(seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(seq_group)
                    else:
                        swapped_out.append(seq_group)
                    break
            else:
                self._append_slots(seq_group, blocks_to_copy)
                is_prefill = seq_group.is_prefill()

                # Pooled wrapper object; both fields are overwritten here.
                scheduled_seq_group: ScheduledSequenceGroup = \
                    self._scheduled_seq_group_cache[self.cache_id].get_object()
                scheduled_seq_group.seq_group = seq_group
                if is_prefill:
                    scheduled_seq_group.token_chunk_size = num_running_tokens
                    prefill_seq_groups.append(scheduled_seq_group)
                    ret.prefill_seq_groups_list.append(seq_group)
                else:
                    scheduled_seq_group.token_chunk_size = 1
                    decode_seq_groups.append(scheduled_seq_group)
                    ret.decode_seq_groups_list.append(seq_group)

                budget.add_num_batched_tokens(seq_group.request_id,
                                              num_running_tokens)
                # OPTIMIZATION: Note that get_max_num_running_seqs is
                # expensive. For the default scheduling case where
                # enable_chunking is False, num_seqs are updated before running
                # this method, so we don't have to update it again here.
                if enable_chunking:
                    num_running_seqs = seq_group.get_max_num_running_seqs()
                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.add(seq_group.lora_int_id)

        # Reset the other pool generation so it is ready for the next call.
        # NOTE(review): when num_cache_iters == 1, next_cache_id equals
        # cache_id, so this resets the pool just used -- presumably
        # PyObjectCache.reset only rewinds its index; confirm.
        self._scheduler_running_outputs_cache[self.next_cache_id].reset()
        self._scheduled_seq_group_cache[self.next_cache_id].reset()

        return ret
    def _schedule_swapped(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerSwappedInOutputs:
        """Schedule sequence groups that are swapped out.

        Groups are examined from the head of ``self.swapped``. They are
        swapped in as long as they fit ``budget`` and, with LoRA enabled,
        the ``max_loras`` limit of the LoRA config. ``budget`` and
        ``curr_loras`` are updated in place for every group swapped in.

        Scanning stops at the first group for which either:
        - the block manager reports ``AllocStatus.LATER``, or
        - no new tokens are available, or the group does not fit the budget.

        Groups reported as ``AllocStatus.NEVER`` are marked
        ``FINISHED_IGNORED`` and returned as infeasible. Groups skipped only
        for lack of a LoRA slot are put back at the head of the queue in
        their original order.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are swapped in.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are swapped in.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerSwappedInOutputs.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: List[Tuple[int, int]] = []
        blocks_to_copy: List[Tuple[int, int]] = []
        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
        infeasible_seq_groups: List[SequenceGroup] = []

        swapped_queue = self.swapped

        # Groups deferred because no LoRA slot is free. Built with
        # appendleft, so the extendleft() after the loop restores their
        # original relative order at the head of the queue.
        leftover_swapped: Deque[SequenceGroup] = deque()
        while swapped_queue:
            seq_group = swapped_queue[0]

            # If the sequence group cannot be swapped in, stop.
            is_prefill = seq_group.is_prefill()
            alloc_status = self.block_manager.can_swap_in(
                seq_group, self._get_num_lookahead_slots(is_prefill))
            if alloc_status == AllocStatus.LATER:
                break
            elif alloc_status == AllocStatus.NEVER:
                logger.warning(f"Failing the request {seq_group.request_id} "
                               "because there's not enough kv cache blocks to "
                               "run the entire sequence.")
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible_seq_groups.append(seq_group)
                swapped_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_swapped.appendleft(seq_group)
                    swapped_queue.popleft()
                    continue

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)

            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
                break

            if lora_int_id > 0 and curr_loras is not None:
                curr_loras.add(lora_int_id)
            swapped_queue.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slots(seq_group, blocks_to_copy)
            # NOTE(review): is_prefill is recomputed after the swap-in; it is
            # presumably unchanged by it, so the value above could likely be
            # reused -- confirm.
            is_prefill = seq_group.is_prefill()
            if is_prefill:
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(seq_group,
                                           token_chunk_size=num_new_tokens))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Put LoRA-deferred groups back at the head, in original order.
        swapped_queue.extendleft(leftover_swapped)

        return SchedulerSwappedInOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False),
            infeasible_seq_groups=infeasible_seq_groups,
        )
  635. def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
  636. if self.scheduler_config.chunked_prefill_enabled:
  637. prompt_limit = self.scheduler_config.max_model_len
  638. else:
  639. prompt_limit = min(self.scheduler_config.max_model_len,
  640. self.scheduler_config.max_num_batched_tokens)
  641. # Model is fine tuned with long context. Return the fine tuned max_len.
  642. if (seq_group.lora_request
  643. and seq_group.lora_request.long_lora_max_len):
  644. assert prompt_limit <= seq_group.lora_request.long_lora_max_len
  645. return seq_group.lora_request.long_lora_max_len
  646. else:
  647. return prompt_limit
  def _schedule_prefills(
      self,
      budget: SchedulingBudget,
      curr_loras: Optional[Set[int]],
      enable_chunking: bool = False,
  ) -> SchedulerPrefillOutputs:
      """Schedule sequence groups that are in prefill stage.

      Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
      as a new prefill (that starts from beginning -> most recently generated
      tokens).

      It schedules waiting requests as long as it fits `budget` and
      curr_loras <= max_lora from the scheduling config. The input arguments
      `budget` and `curr_loras` are updated based on scheduled seq_groups.

      Requests are examined from the head of `self.waiting`:
        * a prompt longer than `_get_prompt_limit`, or one the block manager
          reports it can NEVER allocate, is marked FINISHED_IGNORED, removed
          from the queue and returned in `ignored_seq_groups`;
        * a request whose LoRA would exceed `max_loras` is skipped for now
          and put back at the front of the waiting queue (original order
          preserved);
        * scheduling stops at the first request that can only be allocated
          LATER, that does not fit `budget`, or once `_passed_delay` says the
          prompt delay has not elapsed.

      Args:
          budget: The scheduling budget. The argument is in-place updated
              when any requests are scheduled.
          curr_loras: Currently batched lora request ids. The argument is
              in-place updated when any requests are scheduled.
          enable_chunking: If True, seq group can be chunked and only a
              chunked number of tokens are scheduled if
              `budget.num_batched_tokens` has not enough capacity to schedule
              all tokens.

      Returns:
          SchedulerPrefillOutputs.
      """
      ignored_seq_groups: List[SequenceGroup] = []
      # Holds ScheduledSequenceGroup wrappers (see the append below), not raw
      # SequenceGroups.
      seq_groups: List[ScheduledSequenceGroup] = []

      waiting_queue = self.waiting

      # Requests skipped only because of the LoRA slot limit.
      leftover_waiting_sequences: Deque[SequenceGroup] = deque()
      # `_passed_delay` is re-evaluated on every iteration; it also updates
      # the prompt-latency bookkeeping it uses (prev_time / prev_prompt).
      while self._passed_delay(time.time()) and waiting_queue:
          seq_group = waiting_queue[0]

          waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
          assert len(waiting_seqs) == 1, (
              "Waiting sequence group should have only one prompt "
              "sequence.")
          num_new_tokens = self._get_num_new_tokens(seq_group,
                                                    SequenceStatus.WAITING,
                                                    enable_chunking, budget)
          if not enable_chunking:
              # Without chunking the whole prompt must be scheduled at once.
              num_prompt_tokens = waiting_seqs[0].get_len()
              assert num_new_tokens == num_prompt_tokens

          prompt_limit = self._get_prompt_limit(seq_group)
          if num_new_tokens > prompt_limit:
              logger.warning(f"Input prompt ({num_new_tokens} tokens) is "
                             f"too long and exceeds limit of {prompt_limit}")
              for seq in waiting_seqs:
                  seq.status = SequenceStatus.FINISHED_IGNORED
              ignored_seq_groups.append(seq_group)
              waiting_queue.popleft()
              continue

          # If the sequence group cannot be allocated, stop.
          can_allocate = self.block_manager.can_allocate(seq_group)
          if can_allocate == AllocStatus.LATER:
              break
          elif can_allocate == AllocStatus.NEVER:
              logger.warning(f"Input prompt ({num_new_tokens} tokens) is "
                             "too long and exceeds the capacity of "
                             "block_manager")
              for seq in waiting_seqs:
                  seq.status = SequenceStatus.FINISHED_IGNORED
              ignored_seq_groups.append(seq_group)
              waiting_queue.popleft()
              continue

          lora_int_id = 0
          if self.lora_enabled:
              lora_int_id = seq_group.lora_int_id
              assert curr_loras is not None
              assert self.lora_config is not None
              if (self.lora_enabled and lora_int_id > 0
                      and lora_int_id not in curr_loras
                      and len(curr_loras) >= self.lora_config.max_loras):
                  # We don't have a space for another LoRA, so
                  # we ignore this request for now.
                  # appendleft here + extendleft after the loop restores the
                  # original relative order at the front of the queue.
                  leftover_waiting_sequences.appendleft(seq_group)
                  waiting_queue.popleft()
                  continue

          num_new_seqs = seq_group.get_max_num_running_seqs()
          # `_get_num_new_tokens` returns 0 when the token budget cannot fit
          # any of this request's tokens.
          if (num_new_tokens == 0
                  or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                             num_new_seqs=num_new_seqs)):
              break

          # Can schedule this request.
          if curr_loras is not None and lora_int_id > 0:
              curr_loras.add(lora_int_id)
          waiting_queue.popleft()
          self._allocate_and_set_running(seq_group)
          seq_group.init_multi_step(
              num_scheduler_steps=self._get_num_lookahead_slots(
                  is_prefill=True) + 1)
          seq_groups.append(
              ScheduledSequenceGroup(seq_group=seq_group,
                                     token_chunk_size=num_new_tokens))
          budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
          budget.add_num_seqs(seq_group.request_id, num_new_seqs)

      # Queue requests that couldn't be scheduled.
      waiting_queue.extendleft(leftover_waiting_sequences)
      if len(seq_groups) > 0:
          # Read by `_passed_delay` on its next call to measure prompt latency.
          self.prev_prompt = True

      return SchedulerPrefillOutputs(
          seq_groups=seq_groups,
          ignored_seq_groups=ignored_seq_groups,
          num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
  def _schedule_default(self) -> SchedulerOutputs:
      """Schedule queued requests.

      The current policy is designed to optimize the throughput. First,
      it batches as many prefill requests as possible. And it schedules
      decodes. If there's a pressure on GPU memory, decode requests can
      be swapped or preempted.

      A step produced by this policy contains either new prefills or
      decodes (running + swapped-in), never both:
        * new prefills are only attempted when nothing is swapped out;
        * running decodes are only scheduled when no prefill was scheduled;
        * swap-ins are skipped if this step preempted or swapped out any
          running group.

      Side effects: preempted groups are pushed to the front of
      `self.waiting`, swapped-out groups are appended to `self.swapped`, and
      newly scheduled prefills, running decodes and swapped-in decodes are
      appended to `self.running`.

      Returns:
          SchedulerOutputs with prefills ordered before decodes in
          `scheduled_seq_groups`.
      """
      # Include running requests to the budget.
      budget = SchedulingBudget(
          token_budget=self.scheduler_config.max_num_batched_tokens,
          max_num_seqs=self.scheduler_config.max_num_seqs,
      )
      # Make sure we include num running seqs before scheduling prefill,
      # so that we don't schedule beyond max_num_seqs for prefill.
      for seq_group in self.running:
          budget.add_num_seqs(seq_group.request_id,
                              seq_group.get_max_num_running_seqs())
      # LoRA ids already occupied by running groups (None if LoRA is off).
      curr_loras = set(
          seq_group.lora_int_id for seq_group in self.running
          if seq_group.lora_int_id > 0) if self.lora_enabled else None

      prefills = SchedulerPrefillOutputs.create_empty()
      running_scheduled = SchedulerRunningOutputs.create_empty()
      swapped_in = SchedulerSwappedInOutputs.create_empty()

      # If any requests are swapped, prioritized swapped requests.
      if not self.swapped:
          prefills = self._schedule_prefills(budget,
                                             curr_loras,
                                             enable_chunking=False)

      # Don't schedule decodes if prefills are scheduled.
      # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
      # only contains decode requests, not chunked prefills.
      if len(prefills.seq_groups) == 0:
          running_scheduled = self._schedule_running(budget,
                                                     curr_loras,
                                                     enable_chunking=False)

          # If any sequence group is preempted, do not swap in any sequence
          # group. because it means there's no slot for new running requests.
          if len(running_scheduled.preempted) + len(
                  running_scheduled.swapped_out) == 0:
              swapped_in = self._schedule_swapped(budget, curr_loras)

      assert (budget.num_batched_tokens <=
              self.scheduler_config.max_num_batched_tokens)
      assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

      # Update waiting requests.
      self.waiting.extendleft(running_scheduled.preempted)
      # Update new running requests.
      if len(prefills.seq_groups) > 0:
          self.running.extend([s.seq_group for s in prefills.seq_groups])

      self.running.extend(running_scheduled.decode_seq_groups_list)

      if len(swapped_in.decode_seq_groups) > 0:
          self.running.extend(
              [s.seq_group for s in swapped_in.decode_seq_groups])

      # Update swapped requests.
      self.swapped.extend(running_scheduled.swapped_out)
      preempted = (len(running_scheduled.preempted) +
                   len(running_scheduled.swapped_out))

      # There should be no prefill from running queue because this policy
      # doesn't allow chunked prefills.
      assert len(running_scheduled.prefill_seq_groups) == 0
      assert len(swapped_in.prefill_seq_groups) == 0

      # Merge lists
      # NOTE(review): the lists below are extended in place, so the
      # `prefills` / `running_scheduled` output objects are mutated; this is
      # harmless only as long as nothing reads them again afterwards.
      num_prefill_groups = len(prefills.seq_groups)
      if num_prefill_groups > 0:
          scheduled_seq_groups = prefills.seq_groups
          scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
      else:
          scheduled_seq_groups = running_scheduled.decode_seq_groups
      scheduled_seq_groups.extend(swapped_in.decode_seq_groups)

      blocks_to_copy = running_scheduled.blocks_to_copy
      blocks_to_copy.extend(swapped_in.blocks_to_copy)

      ignored_seq_groups = prefills.ignored_seq_groups
      ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)

      return SchedulerOutputs(
          scheduled_seq_groups=scheduled_seq_groups,
          num_prefill_groups=num_prefill_groups,
          num_batched_tokens=budget.num_batched_tokens,
          blocks_to_swap_in=swapped_in.blocks_to_swap_in,
          blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
          blocks_to_copy=blocks_to_copy,
          ignored_seq_groups=ignored_seq_groups,
          num_lookahead_slots=running_scheduled.num_lookahead_slots,
          running_queue_size=len(self.running),
          preempted=preempted,
      )
  def _schedule_chunked_prefill(self) -> SchedulerOutputs:
      """Schedule queued requests.

      Chunked prefill allows to chunk prefill requests, batch them together
      with decode requests. This policy 1. schedule as many decoding requests
      as possible. 2. schedule chunked prefill requests that are not
      finished. 3. schedule swapped request. 4. schedule new prefill
      requests.

      The policy can sustain the high GPU utilization because it can put
      prefill and decodes requests to the same batch, while it improves
      inter token latency because decodes requests don't need to be blocked
      by prefill requests.

      All phases share one `SchedulingBudget` and one `curr_loras` set, so
      each later phase only gets what the earlier phases left over. Swap-ins
      are skipped whenever the running phase preempted or swapped out a
      group.

      Side effects: preempted groups are pushed to the front of
      `self.waiting`, swapped-out groups are appended to `self.swapped`, and
      every scheduled group is appended to `self.running`.

      Returns:
          SchedulerOutputs whose `scheduled_seq_groups` lists all prefill
          groups before all decode groups.
      """
      budget = SchedulingBudget(
          token_budget=self.scheduler_config.max_num_batched_tokens,
          max_num_seqs=self.scheduler_config.max_num_seqs,
      )
      # Starts empty; presumably filled with running LoRA ids by
      # `_schedule_running` (not visible here) — verify.
      curr_loras: Set[int] = set()

      prefills = SchedulerPrefillOutputs.create_empty()
      swapped_in = SchedulerSwappedInOutputs.create_empty()

      # Decoding should be always scheduled first by fcfs.
      running_scheduled = self._schedule_running(budget,
                                                 curr_loras,
                                                 enable_chunking=True)

      # Schedule swapped out requests.
      # If preemption happens, it means we don't have space for swap-in.
      if len(running_scheduled.preempted) + len(
              running_scheduled.swapped_out) == 0:
          swapped_in = self._schedule_swapped(budget, curr_loras)

      # Schedule new prefills.
      prefills = self._schedule_prefills(budget,
                                         curr_loras,
                                         enable_chunking=True)

      assert (budget.num_batched_tokens <=
              self.scheduler_config.max_num_batched_tokens)
      assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

      # Update waiting requests.
      self.waiting.extendleft(running_scheduled.preempted)
      # Update new running requests.
      self.running.extend([s.seq_group for s in prefills.seq_groups])
      self.running.extend(
          [s.seq_group for s in running_scheduled.decode_seq_groups])
      self.running.extend(
          [s.seq_group for s in running_scheduled.prefill_seq_groups])
      self.running.extend(
          [s.seq_group for s in swapped_in.decode_seq_groups])
      self.running.extend(
          [s.seq_group for s in swapped_in.prefill_seq_groups])
      # Update swapped requests.
      self.swapped.extend(running_scheduled.swapped_out)
      return SchedulerOutputs(
          # Prefills first, then decodes: `schedule()` assumes this order.
          scheduled_seq_groups=(prefills.seq_groups +
                                running_scheduled.prefill_seq_groups +
                                swapped_in.prefill_seq_groups +
                                running_scheduled.decode_seq_groups +
                                swapped_in.decode_seq_groups),
          num_prefill_groups=(len(prefills.seq_groups) +
                              len(swapped_in.prefill_seq_groups) +
                              len(running_scheduled.prefill_seq_groups)),
          num_batched_tokens=budget.num_batched_tokens,
          blocks_to_swap_in=swapped_in.blocks_to_swap_in,
          blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
          blocks_to_copy=running_scheduled.blocks_to_copy +
          swapped_in.blocks_to_copy,
          ignored_seq_groups=prefills.ignored_seq_groups +
          swapped_in.infeasible_seq_groups,
          num_lookahead_slots=running_scheduled.num_lookahead_slots,
          running_queue_size=len(self.running),
          preempted=(len(running_scheduled.preempted) +
                     len(running_scheduled.swapped_out)),
      )
  904. def _schedule(self) -> SchedulerOutputs:
  905. """Schedule queued requests."""
  906. if self.scheduler_config.chunked_prefill_enabled:
  907. return self._schedule_chunked_prefill()
  908. else:
  909. return self._schedule_default()
  910. def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
  911. """Determine whether or not we have enough space in the KV cache to
  912. continue generation of the sequence group.
  913. """
  914. # It is True only for testing case to trigger artificial preemption.
  915. if (self.enable_artificial_preemption
  916. and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
  917. and self.artificial_preempt_cnt > 0):
  918. self.artificial_preempt_cnt -= 1
  919. return False
  920. # Appending slots only occurs in decoding.
  921. is_prefill = False
  922. return self.block_manager.can_append_slots(
  923. seq_group=seq_group,
  924. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  925. )
  926. def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
  927. no_beam_search = (seq_group.sampling_params.best_of == 1
  928. and not seq_group.sampling_params.use_beam_search)
  929. return no_beam_search
  def schedule(
      self
  ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
      """Run one scheduling step and build per-group metadata for workers.

      Returns:
          A tuple ``(seq_group_metadata_list, scheduler_outputs,
          allow_async_output_proc)``:
            * one metadata object per entry of
              ``scheduler_outputs.scheduled_seq_groups``, in that order. It is
              a full `SequenceGroupMetadata` for a group's first prefill or
              when `send_delta_data` is off, and a
              `SequenceGroupMetadataDelta` otherwise;
            * the outputs of `_schedule()`;
            * True only if `use_async_output_proc` is set and every
              scheduled group passes `_allow_async_output_proc`.
      """
      # Schedule sequence groups.
      # This function call changes the internal states of the scheduler
      # such as self.running, self.swapped, and self.waiting.
      scheduler_outputs = self._schedule()
      now = time.time()

      # With prefix caching this is (re)computed per group inside the loop.
      if not self.cache_config.enable_prefix_caching:
          common_computed_block_nums = []

      allow_async_output_proc: bool = self.use_async_output_proc

      # Create input data structures.
      seq_group_metadata_list: List[SequenceGroupMetadata] = []
      # NOTE(review): `i` is unused.
      for i, scheduled_seq_group in enumerate(
              scheduler_outputs.scheduled_seq_groups):
          seq_group = scheduled_seq_group.seq_group
          token_chunk_size = scheduled_seq_group.token_chunk_size
          seq_group.maybe_set_first_scheduled_time(now)

          # NOTE(review): this cached object is cleared here, but both
          # branches below rebind `seq_group_metadata` to a freshly built
          # object, so the cached instance is never actually reused.
          seq_group_metadata = self._seq_group_metadata_cache[
              self.cache_id].get_object()
          seq_group_metadata.seq_data.clear()
          seq_group_metadata.block_tables.clear()

          # seq_id -> SequenceData
          seq_data: Dict[int, SequenceData] = {}
          # seq_id -> physical block numbers
          block_tables: Dict[int, List[int]] = {}

          if seq_group.is_encoder_decoder():
              # Encoder associated with SequenceGroup
              encoder_seq_data = seq_group.get_encoder_seq().data
              # Block table for cross-attention
              # Also managed at SequenceGroup level
              cross_block_table = self.block_manager.get_cross_block_table(
                  seq_group)
          else:
              encoder_seq_data = None
              cross_block_table = None

          for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
              seq_id = seq.seq_id
              seq_data[seq_id] = seq.data
              block_tables[seq_id] = self.block_manager.get_block_table(seq)
              self.block_manager.access_all_blocks_in_seq(seq, now)

          if self.cache_config.enable_prefix_caching:
              common_computed_block_nums = (
                  self.block_manager.get_common_computed_block_ids(
                      seq_group.get_seqs(status=SequenceStatus.RUNNING)))

          do_sample = True
          is_prompt = seq_group.is_prefill()
          # We should send the metadata to workers when the first prefill
          # is sent. Subsequent requests could be chunked prefill or decode.
          is_first_prefill = False
          if is_prompt:
              seqs = seq_group.get_seqs()
              # Prefill has only 1 sequence.
              assert len(seqs) == 1
              num_computed_tokens = seqs[0].data.get_num_computed_tokens()
              is_first_prefill = num_computed_tokens == 0
              # In the next iteration, all prompt tokens are not computed.
              # It means the prefill is chunked, and we don't need sampling.
              # NOTE: We use get_len instead of get_prompt_len because when
              # a sequence is preempted, prefill includes previous generated
              # output tokens.
              if (token_chunk_size + num_computed_tokens <
                      seqs[0].data.get_len()):
                  do_sample = False

          # It assumes the scheduled_seq_groups is ordered by
          # prefill < decoding.
          if is_first_prefill or not self.scheduler_config.send_delta_data:
              seq_group_metadata = SequenceGroupMetadata(
                  request_id=seq_group.request_id,
                  is_prompt=is_prompt,
                  seq_data=seq_data,
                  sampling_params=seq_group.sampling_params,
                  block_tables=block_tables,
                  do_sample=do_sample,
                  pooling_params=seq_group.pooling_params,
                  token_chunk_size=token_chunk_size,
                  lora_request=seq_group.lora_request,
                  computed_block_nums=common_computed_block_nums,
                  encoder_seq_data=encoder_seq_data,
                  cross_block_table=cross_block_table,
                  state=seq_group.state,
                  # `multi_modal_data` will only be present for the 1st comm
                  # between engine and worker.
                  # the subsequent comms can still use delta, but
                  # `multi_modal_data` will be None.
                  multi_modal_data=seq_group.multi_modal_data
                  if scheduler_outputs.num_prefill_groups > 0 else None,
                  prompt_adapter_request=seq_group.prompt_adapter_request,
              )
          else:
              # When SPMD mode is enabled, we only send delta data except for
              # the first request to reduce serialization cost.
              seq_data_delta = {}
              # NOTE(review): `id` shadows the builtin within this loop.
              for id, data in seq_data.items():
                  seq_data_delta[id] = data.get_delta_and_reset()
              seq_group_metadata = SequenceGroupMetadataDelta(
                  seq_data_delta,
                  seq_group.request_id,
                  block_tables,
                  is_prompt,
                  do_sample=do_sample,
                  token_chunk_size=token_chunk_size,
                  computed_block_nums=common_computed_block_nums,
              )
          seq_group_metadata_list.append(seq_group_metadata)

          # Once any group disallows async output processing, the flag stays
          # False for the rest of the batch.
          if allow_async_output_proc:
              allow_async_output_proc = self._allow_async_output_proc(
                  seq_group)

      # Now that the batch has been created, we can assume all blocks in the
      # batch will have been computed before the next scheduling invocation.
      # This is because the engine assumes that a failure in model execution
      # will crash the Aphrodite instance / will not retry.
      for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
          self.block_manager.mark_blocks_as_computed(
              scheduled_seq_group.seq_group,
              scheduled_seq_group.token_chunk_size)

      self._seq_group_metadata_cache[self.next_cache_id].reset()

      # Move to next cache (if exists)
      self.cache_id = self.next_cache_id

      # Return results
      return (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc)
  def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
      """Ask the block manager to fork ``parent_seq`` into ``child_seq``."""
      manager = self.block_manager
      manager.fork(parent_seq, child_seq)
  def free_seq(self, seq: Sequence) -> None:
      """Release ``seq``'s entries through the block manager."""
      release = self.block_manager.free
      release(seq)
  1057. def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
  1058. """Free finished seqs in a sequence group."""
  1059. for seq in seq_group.get_seqs():
  1060. if seq.is_finished():
  1061. self.free_seq(seq)
  def free_finished_seq_groups(self) -> None:
      """Remove finished groups from `running` and free finished sequences.

      For every group in `running`, its finished sequences are freed. A group
      that is itself finished also has its cross-attention blocks freed and
      its request id appended to `_finished_requests_ids`, which is later used
      to update the Mamba cache. It is also dropped from `running`. Groups in
      `_async_stopped` (those that reached max model len) get the same
      finished-group treatment, and that list is then cleared.
      """

      def retire(group) -> None:
          # Free cross-attention block table, if it exists, and record the
          # finished request id.
          self._free_seq_group_cross_attn_blocks(group)
          self._finished_requests_ids.append(group.request_id)

      still_running: Deque[SequenceGroup] = deque()
      for group in self.running:
          if group.is_finished():
              retire(group)
          else:
              still_running.append(group)
          # Runs for every group: an unfinished group may presumably still
          # hold individually finished sequences.
          self._free_finished_seqs(group)
      self.running = still_running

      stopped = self._async_stopped
      if not stopped:
          return
      for group in stopped:
          retire(group)
          self._free_finished_seqs(group)
      stopped.clear()
  def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
      """Allocate ``seq_group`` via the block manager, then move its WAITING
      sequences to RUNNING."""
      self.block_manager.allocate(seq_group)
      waiting = seq_group.get_seqs(status=SequenceStatus.WAITING)
      for waiting_seq in waiting:
          waiting_seq.status = SequenceStatus.RUNNING
  def _append_slots(
      self,
      seq_group: SequenceGroup,
      blocks_to_copy: List[Tuple[int, int]],
  ) -> None:
      """Append new slots to every RUNNING sequence in ``seq_group``.

      The group's multi-step state is initialised with one step more than
      the decode lookahead slot count.

      Args:
          seq_group (SequenceGroup): The sequence group whose sequences get
              new slots.
          blocks_to_copy (List[Tuple[int, int]]): List of (source block
              index, destination block index) pairs. It is extended in place
              with any copy pairs the block manager returns while appending.
      """
      lookahead = self._get_num_lookahead_slots(is_prefill=False)
      seq_group.init_multi_step(num_scheduler_steps=lookahead + 1)

      running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
      for running_seq in running_seqs:
          copy_pairs = self.block_manager.append_slots(running_seq, lookahead)
          if copy_pairs:
              blocks_to_copy.extend(copy_pairs)
  def _preempt(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_out: List[Tuple[int, int]],
      preemption_mode: Optional[PreemptionMode] = None,
  ) -> PreemptionMode:
      """Preempt ``seq_group`` to make room in the KV cache.

      Args:
          seq_group: The sequence group to preempt.
          blocks_to_swap_out: List of (source, destination) block index
              pairs; extended in place when the group is swapped out.
          preemption_mode: Mode to apply. An explicit value takes
              precedence over the automatic choice. If ``None`` (the
              default), the mode is chosen as described in the comments
              below.

      Returns:
          The preemption mode that was applied.

      Raises:
          AssertionError: If the resolved mode is neither RECOMPUTE nor
              SWAP.
      """
      # If preemption mode is not specified, we determine the mode as follows:
      # We use recomputation by default since it incurs lower overhead than
      # swapping. However, when the sequence group has multiple sequences
      # (e.g., beam search), recomputation is not currently supported. In
      # such a case, we use swapping instead.
      # FIXME: This makes our scheduling policy a bit bizarre.
      # As swapped sequences are prioritized over waiting sequences,
      # sequence groups with multiple sequences are implicitly prioritized
      # over sequence groups with a single sequence.
      # TODO: Support recomputation for sequence groups with multiple
      # sequences. This may require a more sophisticated CUDA kernel.
      # The caller's explicit `preemption_mode` used to be overwritten
      # unconditionally; only derive a mode when none was given.
      if preemption_mode is None:
          if self.user_specified_preemption_mode is None:
              if seq_group.get_max_num_running_seqs() == 1:
                  preemption_mode = PreemptionMode.RECOMPUTE
              else:
                  preemption_mode = PreemptionMode.SWAP
          elif self.user_specified_preemption_mode == "swap":
              preemption_mode = PreemptionMode.SWAP
          else:
              preemption_mode = PreemptionMode.RECOMPUTE

      # Warn only when the cumulative preemption count is a multiple of 50,
      # to avoid flooding the log.
      if self.num_cumulative_preemption % 50 == 0:
          logger.warning(
              f"Sequence group {seq_group.request_id} is preempted by "
              f"{preemption_mode} mode because there is "
              "not enough KV cache space. This can affect the end-to-end "
              "performance. Increase gpu_memory_utilization or "
              "tensor_parallel_size to provide more KV cache memory. "
              "total_num_cumulative_preemption="
              f"{self.num_cumulative_preemption + 1}")
      self.num_cumulative_preemption += 1

      if preemption_mode == PreemptionMode.RECOMPUTE:
          self._preempt_by_recompute(seq_group)
      elif preemption_mode == PreemptionMode.SWAP:
          self._preempt_by_swap(seq_group, blocks_to_swap_out)
      else:
          raise AssertionError("Invalid preemption mode.")
      return preemption_mode
  def _preempt_by_recompute(
      self,
      seq_group: SequenceGroup,
  ) -> None:
      """Preempt ``seq_group`` so that it is recomputed from scratch later.

      The group must contain exactly one RUNNING sequence. That sequence is
      set back to WAITING, freed via `free_seq`, and has its state reset
      for recomputation.
      """
      running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
      assert len(running_seqs) == 1
      for victim in running_seqs:
          victim.status = SequenceStatus.WAITING
          self.free_seq(victim)
          victim.reset_state_for_recompute()
  1164. def _preempt_by_swap(
  1165. self,
  1166. seq_group: SequenceGroup,
  1167. blocks_to_swap_out: List[Tuple[int, int]],
  1168. ) -> None:
  1169. self._swap_out(seq_group, blocks_to_swap_out)
  def _swap_in(
      self,
      seq_group: SequenceGroup,
      blocks_to_swap_in: List[Tuple[int, int]],
  ) -> None:
      """Swap ``seq_group`` back in and mark its SWAPPED sequences RUNNING.

      The block mapping returned by the block manager is appended to
      ``blocks_to_swap_in`` for the caller.
      """
      block_mapping = self.block_manager.swap_in(seq_group)
      blocks_to_swap_in.extend(block_mapping)
      swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED)
      for restored in swapped_seqs:
          restored.status = SequenceStatus.RUNNING
  1179. def _swap_out(
  1180. self,
  1181. seq_group: SequenceGroup,
  1182. blocks_to_swap_out: List[Tuple[int, int]],
  1183. ) -> None:
  1184. if not self.block_manager.can_swap_out(seq_group):
  1185. # FIXME: Abort the sequence group instead of aborting the
  1186. # entire engine.
  1187. raise RuntimeError(
  1188. "Aborted due to the lack of CPU swap space. Please increase "
  1189. "the swap space to avoid this error.")
  1190. mapping = self.block_manager.swap_out(seq_group)
  1191. blocks_to_swap_out.extend(mapping)
  1192. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  1193. seq.status = SequenceStatus.SWAPPED
  1194. def _passed_delay(self, now: float) -> bool:
  1195. if self.prev_prompt:
  1196. self.last_prompt_latency = now - self.prev_time
  1197. self.prev_time, self.prev_prompt = now, False
  1198. # Delay scheduling prompts to let waiting queue fill up
  1199. if self.scheduler_config.delay_factor > 0 and self.waiting:
  1200. earliest_arrival_time = min(
  1201. [e.metrics.arrival_time for e in self.waiting])
  1202. passed_delay = (
  1203. (now - earliest_arrival_time) >
  1204. (self.scheduler_config.delay_factor * self.last_prompt_latency)
  1205. or not self.running)
  1206. else:
  1207. passed_delay = True
  1208. return passed_delay
  1209. def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
  1210. """The number of slots to allocate per sequence per step, beyond known
  1211. token ids. Speculative decoding uses these slots to store KV activations
  1212. of tokens which may or may not be accepted.
  1213. Speculative decoding does not yet support prefill, so we do not perform
  1214. lookahead allocation for prefill.
  1215. """
  1216. if is_prefill:
  1217. return 0
  1218. return self.scheduler_config.num_lookahead_slots
  1219. def _get_num_new_tokens(self, seq_group: SequenceGroup,
  1220. status: SequenceStatus, enable_chunking: bool,
  1221. budget: SchedulingBudget) -> int:
  1222. """Get the next new tokens to compute for a given sequence group
  1223. that's in a given `status`.
  1224. The API could chunk the number of tokens to compute based on `budget`
  1225. if `enable_chunking` is True. If a sequence group has multiple
  1226. sequences (e.g., running beam search), it means it is in decoding
  1227. phase, so chunking doesn't happen.
  1228. Returns 0 if the new token cannot be computed due to token budget.
  1229. """
  1230. num_new_tokens = 0
  1231. seqs = seq_group.get_seqs(status=status)
  1232. for seq in seqs:
  1233. num_new_tokens += seq.get_num_new_tokens()
  1234. assert num_new_tokens > 0
  1235. # Chunk if a running request cannot fit in the given budget.
  1236. # If number of seq > 1, it means it is doing beam search
  1237. # in a decode phase. Do not chunk.
  1238. if enable_chunking and len(seqs) == 1:
  1239. remaining_token_budget = budget.remaining_token_budget()
  1240. if self.cache_config.enable_prefix_caching:
  1241. # When prefix caching is enabled, we always allocate
  1242. # the number of new tokens that is dividable by the block size
  1243. # to avoid partial block matching.
  1244. block_size = self.cache_config.block_size
  1245. reminder = budget.token_budget % block_size
  1246. if reminder != 0:
  1247. raise ValueError("When enabling chunked prefill and "
  1248. "prefix caching, max_num_batched_tokens "
  1249. "(chunk size) must be dividable by "
  1250. "block size, but got chunk_size "
  1251. f"({budget.token_budget}) % block_size "
  1252. f"({block_size}) = {reminder}")
  1253. if remaining_token_budget < num_new_tokens:
  1254. num_new_tokens = (remaining_token_budget //
  1255. block_size) * block_size
  1256. else:
  1257. num_new_tokens = min(num_new_tokens, remaining_token_budget)
  1258. return num_new_tokens