scheduler.py

  1. import enum
  2. import os
  3. import random
  4. import time
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
  8. Tuple, Union)
  9. from loguru import logger
  10. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  11. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  12. SequenceGroupMetadata,
  13. SequenceGroupMetadataDelta,
  14. SequenceStatus)
  15. from aphrodite.common.utils import Device, PyObjectCache
  16. from aphrodite.lora.request import LoRARequest
  17. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  18. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  19. # Test-only. If configured, decode requests are preempted with
  20. # probability ARTIFICIAL_PREEMPTION_PROB.
  21. ENABLE_ARTIFICIAL_PREEMPT = bool(
  22. os.getenv("APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
  23. ARTIFICIAL_PREEMPTION_PROB = 0.5
  24. ARTIFICIAL_PREEMPTION_MAX_CNT = 500
  25. class PreemptionMode(enum.Enum):
  26. """Preemption modes.
  27. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  28. and swap them back in when the sequences are resumed.
  29. 2. Recomputation: Discard the blocks of the preempted sequences and
  30. recompute them when the sequences are resumed, treating the sequences as
  31. new prompts.
  32. """
  33. SWAP = enum.auto()
  34. RECOMPUTE = enum.auto()
  35. @dataclass
  36. class SchedulingBudget:
  37. """The available slots for scheduling.
  38. TODO: Right now, the budget is request_id-aware, meaning it can ignore
  39. budget updates from the same request_id. This is because in the normal
  40. scheduling path, we update RUNNING num_seqs ahead of time, so it could be
  41. updated more than once when scheduling RUNNING requests. Since this won't
  42. happen if we only have chunked prefill scheduling, we can remove this
  43. feature from the API once chunked prefill is enabled by default.
  44. """
  45. token_budget: int
  46. max_num_seqs: int
  47. _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
  48. _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
  49. _num_batched_tokens: int = 0
  50. _num_curr_seqs: int = 0
  51. def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
  52. assert num_new_tokens != 0
  53. assert num_new_seqs != 0
  54. return (self.num_batched_tokens + num_new_tokens <= self.token_budget
  55. and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
  56. def remaining_token_budget(self):
  57. return self.token_budget - self.num_batched_tokens
  58. def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
  59. if req_id in self._request_ids_num_batched_tokens:
  60. return
  61. self._request_ids_num_batched_tokens.add(req_id)
  62. self._num_batched_tokens += num_batched_tokens
  63. def subtract_num_batched_tokens(self, req_id: str,
  64. num_batched_tokens: int):
  65. if req_id in self._request_ids_num_batched_tokens:
  66. self._request_ids_num_batched_tokens.remove(req_id)
  67. self._num_batched_tokens -= num_batched_tokens
  68. def add_num_seqs(self, req_id: str, num_curr_seqs: int):
  69. if req_id in self._request_ids_num_curr_seqs:
  70. return
  71. self._request_ids_num_curr_seqs.add(req_id)
  72. self._num_curr_seqs += num_curr_seqs
  73. def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
  74. if req_id in self._request_ids_num_curr_seqs:
  75. self._request_ids_num_curr_seqs.remove(req_id)
  76. self._num_curr_seqs -= num_curr_seqs
  77. @property
  78. def num_batched_tokens(self):
  79. return self._num_batched_tokens
  80. @property
  81. def num_curr_seqs(self):
  82. return self._num_curr_seqs
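# Illustrative sketch of the request_id-aware deduplication above; the budget
# values and request id are made-up examples, not from the original module.
#
#   >>> budget = SchedulingBudget(token_budget=2048, max_num_seqs=256)
#   >>> budget.can_schedule(num_new_tokens=512, num_new_seqs=1)
#   True
#   >>> budget.add_num_batched_tokens("req-0", 512)
#   >>> budget.add_num_batched_tokens("req-0", 512)  # ignored: same request_id
#   >>> budget.num_batched_tokens
#   512
#   >>> budget.remaining_token_budget()
#   1536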
  83. @dataclass
  84. class ScheduledSequenceGroup:
  85. # A sequence group that's scheduled.
  86. seq_group: SequenceGroup
  87. # The total chunk size (number of tokens) to process for the next iteration.
  88. # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
  89. # chunked, it can be smaller than that.
  90. token_chunk_size: int
  91. @dataclass
  92. class SchedulerOutputs:
  93. """The scheduling decision made from a scheduler."""
  94. # Scheduled sequence groups.
  95. scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
  96. # Number of prefill groups scheduled.
  97. num_prefill_groups: int
  98. # Total number of batched tokens.
  99. num_batched_tokens: int
  100. # Blocks to swap in. List of CPU -> GPU block number.
  101. blocks_to_swap_in: List[Tuple[int, int]]
  102. # Blocks to swap out. List of GPU -> CPU block number.
  103. blocks_to_swap_out: List[Tuple[int, int]]
  104. # Blocks to copy. Source to dest block.
  105. blocks_to_copy: List[Tuple[int, int]]
  106. # Sequence groups that are going to be ignored.
  107. ignored_seq_groups: List[SequenceGroup]
  108. # The number of slots for lookahead decoding.
  109. num_lookahead_slots: int
  110. # The number of requests in the running queue
  111. running_queue_size: int
  112. preempted: int
  113. def __post_init__(self):
  114. # Swap in and swap out should never happen at the same time.
  115. assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
  116. self.num_loras: int = len(self.lora_requests)
  117. if self.num_loras > 0:
  118. self._sort_by_lora_ids()
  119. self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
  120. def is_empty(self) -> bool:
  121. # NOTE: We do not consider the ignored sequence groups.
  122. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  123. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  124. def _sort_by_lora_ids(self):
  125. self.scheduled_seq_groups = sorted(
  126. self.scheduled_seq_groups,
  127. key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
  128. @property
  129. def lora_requests(self) -> Set[LoRARequest]:
  130. return {
  131. g.seq_group.lora_request
  132. for g in self.scheduled_seq_groups
  133. if g.seq_group.lora_request is not None
  134. }
  135. @property
  136. def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
  137. return {
  138. g.seq_group.prompt_adapter_request
  139. for g in self.scheduled_seq_groups
  140. if g.seq_group.prompt_adapter_request is not None
  141. }
  142. @dataclass
  143. class SchedulerRunningOutputs:
  144. """The requests that are scheduled from a running queue.
  145. Could contain prefills (i.e., chunked prefills) or decodes. If there's not
  146. enough memory, requests can be preempted (for recompute) or swapped out.
  147. """
  148. # Selected sequences that are running and in a decoding phase.
  149. decode_seq_groups: List[ScheduledSequenceGroup]
  150. # Selected sequences that are running and in a prefill phase.
  151. # I.e., it means the prefill has been chunked.
  152. prefill_seq_groups: List[ScheduledSequenceGroup]
  153. # The preempted sequences.
  154. preempted: List[SequenceGroup]
  155. # Sequences that are swapped out.
  156. swapped_out: List[SequenceGroup]
  157. # The blocks to swap out.
  158. blocks_to_swap_out: List[Tuple[int, int]]
  159. # The blocks to copy.
  160. blocks_to_copy: List[Tuple[int, int]]
  161. # The number of slots for lookahead decoding.
  162. num_lookahead_slots: int
  163. # Optimization for fast-access to seq_group lists
  164. decode_seq_groups_list: List[SequenceGroup]
  165. prefill_seq_groups_list: List[SequenceGroup]
  166. @classmethod
  167. def create_empty(cls) -> "SchedulerRunningOutputs":
  168. return SchedulerRunningOutputs(
  169. decode_seq_groups=[],
  170. prefill_seq_groups=[],
  171. preempted=[],
  172. swapped_out=[],
  173. blocks_to_swap_out=[],
  174. blocks_to_copy=[],
  175. num_lookahead_slots=0,
  176. decode_seq_groups_list=[],
  177. prefill_seq_groups_list=[],
  178. )
  179. @dataclass
  180. class SchedulerSwappedInOutputs:
  181. """The requests that are scheduled from a swap queue.
  182. Could contain prefills (i.e., chunked prefills) or decodes.
  183. """
  184. # Selected sequences that are going to be swapped in and are in a
  185. # decoding phase.
  186. decode_seq_groups: List[SequenceGroup]
  187. # Selected sequences that are going to be swapped in and are in a prefill
  188. # phase, i.e., the prefill has been chunked.
  189. prefill_seq_groups: List[SequenceGroup]
  190. # The blocks to swap in.
  191. blocks_to_swap_in: List[Tuple[int, int]]
  192. # The blocks to copy.
  193. blocks_to_copy: List[Tuple[int, int]]
  194. # The number of slots for lookahead decoding.
  195. num_lookahead_slots: int
  196. # Infeasible sequence groups.
  197. infeasible_seq_groups: List[SequenceGroup]
  198. @classmethod
  199. def create_empty(cls) -> "SchedulerSwappedInOutputs":
  200. return SchedulerSwappedInOutputs(
  201. decode_seq_groups=[],
  202. prefill_seq_groups=[],
  203. blocks_to_swap_in=[],
  204. blocks_to_copy=[],
  205. num_lookahead_slots=0,
  206. infeasible_seq_groups=[],
  207. )
  208. @dataclass
  209. class SchedulerPrefillOutputs:
  210. """The requests that are scheduled from a waiting queue.
  211. Could contain fresh prefill requests or preempted requests that need
  212. to be recomputed from scratch.
  213. """
  214. # Selected sequences for prefill.
  215. seq_groups: List[SequenceGroup]
  216. # Ignored sequence groups.
  217. ignored_seq_groups: List[SequenceGroup]
  218. num_lookahead_slots: int
  219. @classmethod
  220. def create_empty(cls) -> "SchedulerPrefillOutputs":
  221. return SchedulerPrefillOutputs(
  222. seq_groups=[],
  223. ignored_seq_groups=[],
  224. num_lookahead_slots=0,
  225. )
  226. def seq_group_metadata_builder():
  227. return SequenceGroupMetadata(request_id="",
  228. is_prompt=False,
  229. seq_data={},
  230. sampling_params=None,
  231. block_tables={})
  232. def scheduler_running_outputs_builder():
  233. return SchedulerRunningOutputs(decode_seq_groups=[],
  234. prefill_seq_groups=[],
  235. preempted=[],
  236. swapped_out=[],
  237. blocks_to_swap_out=[],
  238. blocks_to_copy=[],
  239. num_lookahead_slots=0,
  240. prefill_seq_groups_list=[],
  241. decode_seq_groups_list=[])
  242. def scheduled_seq_group_builder():
  243. return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
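# The builder functions above are factories handed to PyObjectCache, which
# recycles objects across scheduling iterations instead of reallocating them.
# A minimal sketch of the pattern, based only on the PyObjectCache calls used
# elsewhere in this file (get_object() and reset()); `some_seq_group` is a
# placeholder:
#
#   >>> cache = PyObjectCache(scheduled_seq_group_builder)
#   >>> scheduled = cache.get_object()   # pooled ScheduledSequenceGroup
#   >>> scheduled.seq_group = some_seq_group
#   >>> scheduled.token_chunk_size = 1
#   >>> cache.reset()                    # make pooled objects reusable again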
  244. class Scheduler:
  245. def __init__(
  246. self,
  247. scheduler_config: SchedulerConfig,
  248. cache_config: CacheConfig,
  249. lora_config: Optional[LoRAConfig],
  250. pipeline_parallel_size: int = 1,
  251. output_proc_callback: Optional[Callable] = None,
  252. ) -> None:
  253. self.scheduler_config = scheduler_config
  254. self.cache_config = cache_config
  255. # Note for LoRA scheduling: the current policy is extremely
  256. # simple and NOT fair. It can lead to starvation of some
  257. # LoRAs. This should be improved in the future.
  258. self.lora_config = lora_config
  259. version = "v1"
  260. if self.scheduler_config.use_v2_block_manager:
  261. version = "v2"
  262. if (self.scheduler_config.embedding_mode
  263. or self.scheduler_config.is_attention_free):
  264. version = "placeholder"
  265. BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
  266. version)
  267. num_gpu_blocks = cache_config.num_gpu_blocks
  268. if num_gpu_blocks:
  269. num_gpu_blocks //= pipeline_parallel_size
  270. num_cpu_blocks = cache_config.num_cpu_blocks
  271. if num_cpu_blocks:
  272. num_cpu_blocks //= pipeline_parallel_size
  273. # Create the block space manager.
  274. self.block_manager = BlockSpaceManagerImpl(
  275. block_size=self.cache_config.block_size,
  276. num_gpu_blocks=num_gpu_blocks,
  277. num_cpu_blocks=num_cpu_blocks,
  278. sliding_window=self.cache_config.sliding_window,
  279. enable_caching=self.cache_config.enable_prefix_caching)
  280. # Sequence groups in the WAITING state.
  281. # Contain new prefill or preempted requests.
  282. self.waiting: Deque[SequenceGroup] = deque()
  283. # Sequence groups in the RUNNING state.
  284. # Contain decode requests.
  285. self.running: Deque[SequenceGroup] = deque()
  286. # Sequence groups in the SWAPPED state.
  287. # Contain decode requests that are swapped out.
  288. self.swapped: Deque[SequenceGroup] = deque()
  289. # Request ids of sequence groups finished since the last scheduling step.
  290. # It lets the model know that any state associated with these requests
  291. # can and must be released after the current step.
  292. # This is used to evict the finished requests from the Mamba cache.
  293. self._finished_requests_ids: List[str] = list()
  294. # Time at previous scheduling step
  295. self.prev_time = 0.0
  296. # Did we schedule a prompt at previous step?
  297. self.prev_prompt = False
  298. # Latency of the last prompt step
  299. self.last_prompt_latency = 0.0
  300. # preemption mode, RECOMPUTE or SWAP
  301. self.user_specified_preemption_mode = scheduler_config.preemption_mode
  302. # The following field is test-only. It is used to inject artificial
  303. # preemption.
  304. self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
  305. self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
  306. if self.enable_artificial_preemption
  307. else 0)
  308. self.num_cumulative_preemption: int = 0
  309. # Used to cache python objects
  310. self._seq_group_metadata_cache: List[PyObjectCache] = []
  311. self._scheduler_running_outputs_cache: List[PyObjectCache] = []
  312. self._scheduled_seq_group_cache: List[PyObjectCache] = []
  313. # For async output processing, we need to swap cache buffers between
  314. # iterations. I.e., since output processing lags one step behind,
  315. # we cannot reuse the cached objects immediately when schedule()
  316. # is called again, but only the second time schedule() is called.
  317. self.output_proc_callback = output_proc_callback
  318. self.use_async_output_proc = self.output_proc_callback is not None
  319. self.num_cache_iters = 2 if self.use_async_output_proc else 1
  320. self.cache_id = 0
  321. for i in range(self.num_cache_iters):
  322. self._seq_group_metadata_cache.append(
  323. PyObjectCache(seq_group_metadata_builder))
  324. self._scheduler_running_outputs_cache.append(
  325. PyObjectCache(scheduler_running_outputs_builder))
  326. self._scheduled_seq_group_cache.append(
  327. PyObjectCache(scheduled_seq_group_builder))
  328. # For async postprocessor, the extra decode run cannot be done
  329. # when the request reaches max_model_len. In this case, the request
  330. # will be stopped during the schedule() call and added to this stop list
  331. # for processing and deallocation by free_finished_seq_groups().
  332. self._async_stopped: List[SequenceGroup] = []
  333. @property
  334. def next_cache_id(self):
  335. return (self.cache_id + 1) % self.num_cache_iters
  336. @property
  337. def lora_enabled(self) -> bool:
  338. return bool(self.lora_config)
  339. @property
  340. def num_decoding_tokens_per_seq(self) -> int:
  341. """The number of new tokens."""
  342. return 1
  343. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  344. # Add sequence groups to the waiting queue.
  345. self.waiting.append(seq_group)
  346. def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
  347. # Add sequence groups to the running queue.
  348. # Only for testing purposes.
  349. self.running.append(seq_group)
  350. def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
  351. # Add sequence groups to the swapped queue.
  352. # Only for testing purposes.
  353. self.swapped.append(seq_group)
  354. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  355. """Aborts a sequence group with the given ID.
  356. Check if the sequence group with the given ID
  357. is present in any of the state queues.
  358. If present, remove the sequence group from the state queue.
  359. Also, if any of the sequences in the sequence group is not finished,
  360. free the sequence with status `FINISHED_ABORTED`.
  361. Otherwise, do nothing.
  362. Args:
  363. request_id: The ID(s) of the sequence group to abort.
  364. """
  365. if isinstance(request_id, str):
  366. request_id = (request_id, )
  367. request_ids = set(request_id)
  368. for state_queue in [self.waiting, self.running, self.swapped]:
  369. aborted_groups: List[SequenceGroup] = []
  370. for seq_group in state_queue:
  371. if not request_ids:
  372. # Using 'break' here may add two extra iterations,
  373. # but is acceptable to reduce complexity.
  374. break
  375. if seq_group.request_id in request_ids:
  376. # Appending aborted group into pending list.
  377. aborted_groups.append(seq_group)
  378. request_ids.remove(seq_group.request_id)
  379. for aborted_group in aborted_groups:
  380. # Remove the sequence group from the state queue.
  381. state_queue.remove(aborted_group)
  382. # Remove the aborted request from the Mamba cache.
  383. self._finished_requests_ids.append(aborted_group.request_id)
  384. for seq in aborted_group.get_seqs():
  385. if seq.is_finished():
  386. continue
  387. seq.status = SequenceStatus.FINISHED_ABORTED
  388. self.free_seq(seq)
  389. self._free_seq_group_cross_attn_blocks(aborted_group)
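# Hedged usage sketch: abort_seq_group accepts either a single request id or
# an iterable of ids. `scheduler` and the ids below are hypothetical.
#
#   >>> scheduler.abort_seq_group("req-42")
#   >>> scheduler.abort_seq_group(["req-1", "req-2", "req-3"])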
  390. def _free_seq_group_cross_attn_blocks(
  391. self,
  392. seq_group: SequenceGroup,
  393. ) -> None:
  394. """
  395. Free a sequence group from a cross-attention block table.
  396. Has no effect on decoder-only models.
  397. """
  398. if seq_group.is_encoder_decoder():
  399. self.block_manager.free_cross(seq_group)
  400. def has_unfinished_seqs(self) -> bool:
  401. return len(self.waiting) != 0 or len(self.running) != 0 or len(
  402. self.swapped) != 0
  403. def get_prefix_cache_hit_rate(self, device: Device) -> float:
  404. return self.block_manager.get_prefix_cache_hit_rate(device)
  405. def get_num_unfinished_seq_groups(self) -> int:
  406. return len(self.waiting) + len(self.running) + len(self.swapped)
  407. def get_and_reset_finished_requests_ids(self) -> List[str]:
  408. """Flushes the list of request ids of previously finished seq_groups."""
  409. finished_requests_ids = self._finished_requests_ids
  410. self._finished_requests_ids = list()
  411. return finished_requests_ids
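# Sketch of the flush semantics: the first call drains the list, so a second
# call before any new group finishes returns an empty list. The id shown is
# hypothetical.
#
#   >>> scheduler.get_and_reset_finished_requests_ids()
#   ['req-7']
#   >>> scheduler.get_and_reset_finished_requests_ids()
#   []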
  412. def _schedule_running(
  413. self,
  414. budget: SchedulingBudget,
  415. curr_loras: Optional[Set[int]],
  416. enable_chunking: bool = False,
  417. ) -> SchedulerRunningOutputs:
  418. """Schedule sequence groups that are running.
  419. Running queue should include decode and chunked prefill requests.
  420. Args:
  421. budget: The scheduling budget. The argument is in-place updated
  422. when any decodes are preempted.
  423. curr_loras: Currently batched lora request ids. The argument is
  424. in-place updated when any decodes are preempted.
  425. enable_chunking: If True, seq group can be chunked and only a
  426. chunked number of tokens are scheduled if
  427. `budget.num_batched_tokens` does not have enough capacity to schedule
  428. all tokens.
  429. Returns:
  430. SchedulerRunningOutputs.
  431. """
  432. ret: SchedulerRunningOutputs = \
  433. self._scheduler_running_outputs_cache[self.cache_id].get_object()
  434. ret.blocks_to_swap_out.clear()
  435. ret.blocks_to_copy.clear()
  436. ret.decode_seq_groups.clear()
  437. ret.prefill_seq_groups.clear()
  438. ret.preempted.clear()
  439. ret.swapped_out.clear()
  440. ret.num_lookahead_slots = self._get_num_lookahead_slots(
  441. is_prefill=False)
  442. ret.decode_seq_groups_list.clear()
  443. ret.prefill_seq_groups_list.clear()
  444. # Blocks that need to be swapped or copied before model execution.
  445. blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
  446. blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
  447. decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
  448. prefill_seq_groups: List[
  449. ScheduledSequenceGroup] = ret.prefill_seq_groups
  450. preempted: List[SequenceGroup] = ret.preempted
  451. swapped_out: List[SequenceGroup] = ret.swapped_out
  452. running_queue = self.running
  453. assert len(self._async_stopped) == 0
  454. while running_queue:
  455. seq_group = running_queue[0]
  456. num_running_tokens = self._get_num_new_tokens(
  457. seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
  458. if num_running_tokens == 0:
  459. # No budget => Stop
  460. break
  461. running_queue.popleft()
  462. # With async postprocessor, an extra decode run is done
  463. # to process the final tokens. The check below avoids this extra
  464. # decode run when the model max len is reached, in order to avoid
  465. # a memory overflow.
  466. if self.use_async_output_proc and seq_group.seqs[0].get_len(
  467. ) > self.scheduler_config.max_model_len:
  468. self._async_stopped.append(seq_group)
  469. continue
  470. # NOTE: Preemption happens only when there is no available
  471. # slot to keep all the sequence groups in the RUNNING state.
  472. while not self._can_append_slots(seq_group):
  473. budget.subtract_num_batched_tokens(seq_group.request_id,
  474. num_running_tokens)
  475. num_running_seqs = seq_group.get_max_num_running_seqs()
  476. budget.subtract_num_seqs(seq_group.request_id,
  477. num_running_seqs)
  478. if (curr_loras is not None and seq_group.lora_int_id > 0
  479. and seq_group.lora_int_id in curr_loras):
  480. curr_loras.remove(seq_group.lora_int_id)
  481. # Determine victim sequence
  482. cont_loop = True
  483. if running_queue:
  484. # Preempt the lowest-priority sequence group.
  485. victim_seq_group = running_queue.pop()
  486. else:
  487. # No other sequence group can be preempted.
  488. # Preempt the current sequence group.
  489. # Note: This is also where we stop this loop
  490. # (since there is nothing else to preempt)
  491. victim_seq_group = seq_group
  492. cont_loop = False
  493. # With async postprocessor, before preempting a sequence
  494. # we need to ensure it has no pending async postprocessor
  495. do_preempt = True
  496. if self.use_async_output_proc:
  497. assert self.output_proc_callback is not None
  498. self.output_proc_callback(
  499. request_id=victim_seq_group.request_id)
  500. # It may be that the async pending "victim_seq_group"
  501. # becomes finished, in which case we simply free it.
  502. if victim_seq_group.is_finished():
  503. self._free_finished_seq_group(victim_seq_group)
  504. do_preempt = False
  505. # Do preemption
  506. if do_preempt:
  507. preempted_mode = self._preempt(victim_seq_group,
  508. blocks_to_swap_out)
  509. if preempted_mode == PreemptionMode.RECOMPUTE:
  510. preempted.append(victim_seq_group)
  511. else:
  512. swapped_out.append(victim_seq_group)
  513. if not cont_loop:
  514. break
  515. else:
  516. self._append_slots(seq_group, blocks_to_copy)
  517. is_prefill = seq_group.is_prefill()
  518. scheduled_seq_group: ScheduledSequenceGroup = \
  519. self._scheduled_seq_group_cache[self.cache_id].get_object()
  520. scheduled_seq_group.seq_group = seq_group
  521. if is_prefill:
  522. scheduled_seq_group.token_chunk_size = num_running_tokens
  523. prefill_seq_groups.append(scheduled_seq_group)
  524. ret.prefill_seq_groups_list.append(seq_group)
  525. else:
  526. scheduled_seq_group.token_chunk_size = 1
  527. decode_seq_groups.append(scheduled_seq_group)
  528. ret.decode_seq_groups_list.append(seq_group)
  529. budget.add_num_batched_tokens(seq_group.request_id,
  530. num_running_tokens)
  531. # OPTIMIZATION: Note that get_max_num_running_seqs is
  532. # expensive. For the default scheduling case where
  533. # enable_chunking is False, num_seqs is updated before running
  534. # this method, so we don't have to update it again here.
  535. if enable_chunking:
  536. num_running_seqs = seq_group.get_max_num_running_seqs()
  537. budget.add_num_seqs(seq_group.request_id, num_running_seqs)
  538. if curr_loras is not None and seq_group.lora_int_id > 0:
  539. curr_loras.add(seq_group.lora_int_id)
  540. self._scheduler_running_outputs_cache[self.next_cache_id].reset()
  541. self._scheduled_seq_group_cache[self.next_cache_id].reset()
  542. return ret
  543. def _schedule_swapped(
  544. self,
  545. budget: SchedulingBudget,
  546. curr_loras: Optional[Set[int]],
  547. enable_chunking: bool = False,
  548. ) -> SchedulerSwappedInOutputs:
  549. """Schedule sequence groups that are swapped out.
  550. It schedules swapped requests as long as they fit `budget` and
  551. curr_loras <= max_loras from the scheduling config. The input arguments
  552. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  553. Args:
  554. budget: The scheduling budget. The argument is in-place updated
  555. when any requests are swapped in.
  556. curr_loras: Currently batched lora request ids. The argument is
  557. in-place updated when any requests are swapped in.
  558. enable_chunking: If True, seq group can be chunked and only a
  559. chunked number of tokens are scheduled if
  560. `budget.num_batched_tokens` does not have enough capacity to schedule
  561. all tokens.
  562. Returns:
  563. SchedulerSwappedInOutputs.
  564. """
  565. # Blocks that need to be swapped or copied before model execution.
  566. blocks_to_swap_in: List[Tuple[int, int]] = []
  567. blocks_to_copy: List[Tuple[int, int]] = []
  568. decode_seq_groups: List[ScheduledSequenceGroup] = []
  569. prefill_seq_groups: List[ScheduledSequenceGroup] = []
  570. infeasible_seq_groups: List[SequenceGroup] = []
  571. swapped_queue = self.swapped
  572. leftover_swapped: Deque[SequenceGroup] = deque()
  573. while swapped_queue:
  574. seq_group = swapped_queue[0]
  575. # If the sequence group cannot be swapped in, stop.
  576. is_prefill = seq_group.is_prefill()
  577. alloc_status = self.block_manager.can_swap_in(
  578. seq_group, self._get_num_lookahead_slots(is_prefill))
  579. if alloc_status == AllocStatus.LATER:
  580. break
  581. elif alloc_status == AllocStatus.NEVER:
  582. logger.warning(f"Failing the request {seq_group.request_id} "
  583. "because there's not enough kv cache blocks to "
  584. "run the entire sequence.")
  585. for seq in seq_group.get_seqs():
  586. seq.status = SequenceStatus.FINISHED_IGNORED
  587. infeasible_seq_groups.append(seq_group)
  588. swapped_queue.popleft()
  589. continue
  590. lora_int_id = 0
  591. if self.lora_enabled:
  592. lora_int_id = seq_group.lora_int_id
  593. assert curr_loras is not None
  594. assert self.lora_config is not None
  595. if (lora_int_id > 0 and (lora_int_id not in curr_loras)
  596. and len(curr_loras) >= self.lora_config.max_loras):
  597. # We don't have a space for another LoRA, so
  598. # we ignore this request for now.
  599. leftover_swapped.appendleft(seq_group)
  600. swapped_queue.popleft()
  601. continue
  602. # The total number of sequences in the RUNNING state should not
  603. # exceed the maximum number of sequences.
  604. num_new_seqs = seq_group.get_max_num_running_seqs()
  605. num_new_tokens = self._get_num_new_tokens(seq_group,
  606. SequenceStatus.SWAPPED,
  607. enable_chunking, budget)
  608. if (num_new_tokens == 0
  609. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  610. num_new_seqs=num_new_seqs)):
  611. break
  612. if lora_int_id > 0 and curr_loras is not None:
  613. curr_loras.add(lora_int_id)
  614. swapped_queue.popleft()
  615. self._swap_in(seq_group, blocks_to_swap_in)
  616. self._append_slots(seq_group, blocks_to_copy)
  617. is_prefill = seq_group.is_prefill()
  618. if is_prefill:
  619. prefill_seq_groups.append(
  620. ScheduledSequenceGroup(seq_group,
  621. token_chunk_size=num_new_tokens))
  622. else:
  623. decode_seq_groups.append(
  624. ScheduledSequenceGroup(seq_group, token_chunk_size=1))
  625. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  626. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  627. swapped_queue.extendleft(leftover_swapped)
  628. return SchedulerSwappedInOutputs(
  629. decode_seq_groups=decode_seq_groups,
  630. prefill_seq_groups=prefill_seq_groups,
  631. blocks_to_swap_in=blocks_to_swap_in,
  632. blocks_to_copy=blocks_to_copy,
  633. num_lookahead_slots=self._get_num_lookahead_slots(
  634. is_prefill=False),
  635. infeasible_seq_groups=infeasible_seq_groups,
  636. )
  637. def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
  638. if self.scheduler_config.chunked_prefill_enabled:
  639. prompt_limit = self.scheduler_config.max_model_len
  640. else:
  641. prompt_limit = min(self.scheduler_config.max_model_len,
  642. self.scheduler_config.max_num_batched_tokens)
  643. # Model is fine-tuned with a long context. Return the fine-tuned max_len.
  644. if (seq_group.lora_request
  645. and seq_group.lora_request.long_lora_max_len):
  646. assert prompt_limit <= seq_group.lora_request.long_lora_max_len
  647. return seq_group.lora_request.long_lora_max_len
  648. else:
  649. return prompt_limit
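# Worked example with assumed config values (max_model_len=4096,
# max_num_batched_tokens=2048) and no long-context LoRA override:
#
#   >>> scheduler._get_prompt_limit(seq_group)  # chunked_prefill_enabled=False
#   2048                                        # min(4096, 2048)
#   >>> scheduler._get_prompt_limit(seq_group)  # chunked_prefill_enabled=True
#   4096                                        # max_model_len only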
  650. def _schedule_prefills(
  651. self,
  652. budget: SchedulingBudget,
  653. curr_loras: Optional[Set[int]],
  654. enable_chunking: bool = False,
  655. ) -> SchedulerPrefillOutputs:
  656. """Schedule sequence groups that are in prefill stage.
  657. Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
  658. as a new prefill (which recomputes everything from the beginning of the
  659. prompt up to the most recently generated tokens).
  660. It schedules waiting requests as long as they fit `budget` and
  661. curr_loras <= max_loras from the scheduling config. The input arguments
  662. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  663. Args:
  664. budget: The scheduling budget. The argument is in-place updated
  665. when any requests are scheduled.
  666. curr_loras: Currently batched lora request ids. The argument is
  667. in-place updated when any requests are scheduled.
  668. enable_chunking: If True, seq group can be chunked and only a
  669. chunked number of tokens are scheduled if
  670. `budget.num_batched_tokens` does not have enough capacity to schedule
  671. all tokens.
  672. Returns:
  673. SchedulerPrefillOutputs.
  674. """
  675. ignored_seq_groups: List[SequenceGroup] = []
  676. seq_groups: List[SequenceGroup] = []
  677. waiting_queue = self.waiting
  678. leftover_waiting_sequences: Deque[SequenceGroup] = deque()
  679. while self._passed_delay(time.time()) and waiting_queue:
  680. seq_group = waiting_queue[0]
  681. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  682. assert len(waiting_seqs) == 1, (
  683. "Waiting sequence group should have only one prompt "
  684. "sequence.")
  685. num_new_tokens = self._get_num_new_tokens(seq_group,
  686. SequenceStatus.WAITING,
  687. enable_chunking, budget)
  688. if not enable_chunking:
  689. num_prompt_tokens = waiting_seqs[0].get_len()
  690. assert num_new_tokens == num_prompt_tokens
  691. prompt_limit = self._get_prompt_limit(seq_group)
  692. if num_new_tokens > prompt_limit:
  693. logger.warning(f"Input prompt ({num_new_tokens} tokens) is "
  694. f"too long and exceeds limit of {prompt_limit}")
  695. for seq in waiting_seqs:
  696. seq.status = SequenceStatus.FINISHED_IGNORED
  697. ignored_seq_groups.append(seq_group)
  698. waiting_queue.popleft()
  699. continue
  700. # If the sequence group cannot be allocated, stop.
  701. can_allocate = self.block_manager.can_allocate(seq_group)
  702. if can_allocate == AllocStatus.LATER:
  703. break
  704. elif can_allocate == AllocStatus.NEVER:
  705. logger.warning(f"Input prompt ({num_new_tokens} tokens) is "
  706. "too long and exceeds the capacity of "
  707. "block_manager")
  708. for seq in waiting_seqs:
  709. seq.status = SequenceStatus.FINISHED_IGNORED
  710. ignored_seq_groups.append(seq_group)
  711. waiting_queue.popleft()
  712. continue
  713. lora_int_id = 0
  714. if self.lora_enabled:
  715. lora_int_id = seq_group.lora_int_id
  716. assert curr_loras is not None
  717. assert self.lora_config is not None
  718. if (self.lora_enabled and lora_int_id > 0
  719. and lora_int_id not in curr_loras
  720. and len(curr_loras) >= self.lora_config.max_loras):
  721. # We don't have a space for another LoRA, so
  722. # we ignore this request for now.
  723. leftover_waiting_sequences.appendleft(seq_group)
  724. waiting_queue.popleft()
  725. continue
  726. num_new_seqs = seq_group.get_max_num_running_seqs()
  727. if (num_new_tokens == 0
  728. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  729. num_new_seqs=num_new_seqs)):
  730. break
  731. # Can schedule this request.
  732. if curr_loras is not None and lora_int_id > 0:
  733. curr_loras.add(lora_int_id)
  734. waiting_queue.popleft()
  735. self._allocate_and_set_running(seq_group)
  736. seq_group.init_multi_step(
  737. num_scheduler_steps=self._get_num_lookahead_slots(
  738. is_prefill=True) + 1)
  739. seq_groups.append(
  740. ScheduledSequenceGroup(seq_group=seq_group,
  741. token_chunk_size=num_new_tokens))
  742. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  743. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  744. # Queue requests that couldn't be scheduled.
  745. waiting_queue.extendleft(leftover_waiting_sequences)
  746. if len(seq_groups) > 0:
  747. self.prev_prompt = True
  748. return SchedulerPrefillOutputs(
  749. seq_groups=seq_groups,
  750. ignored_seq_groups=ignored_seq_groups,
  751. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
  752. def _schedule_default(self) -> SchedulerOutputs:
  753. """Schedule queued requests.
  754. The current policy is designed to optimize throughput. First,
  755. it batches as many prefill requests as possible, and then it schedules
  756. decodes. If there is pressure on GPU memory, decode requests can
  757. be swapped out or preempted.
  758. """
  759. # Include running requests to the budget.
  760. budget = SchedulingBudget(
  761. token_budget=self.scheduler_config.max_num_batched_tokens,
  762. max_num_seqs=self.scheduler_config.max_num_seqs,
  763. )
  764. # Make sure we include num running seqs before scheduling prefill,
  765. # so that we don't schedule beyond max_num_seqs for prefill.
  766. for seq_group in self.running:
  767. budget.add_num_seqs(seq_group.request_id,
  768. seq_group.get_max_num_running_seqs())
  769. curr_loras = set(
  770. seq_group.lora_int_id for seq_group in self.running
  771. if seq_group.lora_int_id > 0) if self.lora_enabled else None
  772. prefills = SchedulerPrefillOutputs.create_empty()
  773. running_scheduled = SchedulerRunningOutputs.create_empty()
  774. swapped_in = SchedulerSwappedInOutputs.create_empty()
  775. # If any requests are swapped, prioritize swapped requests.
  776. if not self.swapped:
  777. prefills = self._schedule_prefills(budget,
  778. curr_loras,
  779. enable_chunking=False)
  780. # Don't schedule decodes if prefills are scheduled.
  781. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
  782. # only contains decode requests, not chunked prefills.
  783. if len(prefills.seq_groups) == 0:
  784. running_scheduled = self._schedule_running(budget,
  785. curr_loras,
  786. enable_chunking=False)
  787. # If any sequence group is preempted, do not swap in any sequence
  788. # group, because it means there's no slot for new running requests.
  789. if len(running_scheduled.preempted) + len(
  790. running_scheduled.swapped_out) == 0:
  791. swapped_in = self._schedule_swapped(budget, curr_loras)
  792. assert (budget.num_batched_tokens <=
  793. self.scheduler_config.max_num_batched_tokens)
  794. assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
  795. # Update waiting requests.
  796. self.waiting.extendleft(running_scheduled.preempted)
  797. # Update new running requests.
  798. # By default, the Aphrodite scheduler prioritizes prefills.
  799. # Once chunked prefill is enabled,
  800. # the policy changes to prioritize decode requests.
  801. self.running.extend(
  802. [s.seq_group for s in swapped_in.decode_seq_groups])
  803. self.running.extend(
  804. [s.seq_group for s in swapped_in.prefill_seq_groups])
  805. self.running.extend(
  806. [s.seq_group for s in running_scheduled.decode_seq_groups])
  807. self.running.extend(
  808. [s.seq_group for s in running_scheduled.prefill_seq_groups])
  809. self.running.extend([s.seq_group for s in prefills.seq_groups])
  810. # Update swapped requests.
  811. self.swapped.extend(running_scheduled.swapped_out)
  812. preempted = (len(running_scheduled.preempted) +
  813. len(running_scheduled.swapped_out))
  814. # There should be no prefill from running queue because this policy
  815. # doesn't allow chunked prefills.
  816. assert len(running_scheduled.prefill_seq_groups) == 0
  817. assert len(swapped_in.prefill_seq_groups) == 0
  818. # Merge lists
  819. num_prefill_groups = len(prefills.seq_groups)
  820. if num_prefill_groups > 0:
  821. scheduled_seq_groups = prefills.seq_groups
  822. scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
  823. else:
  824. scheduled_seq_groups = running_scheduled.decode_seq_groups
  825. scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
  826. blocks_to_copy = running_scheduled.blocks_to_copy
  827. blocks_to_copy.extend(swapped_in.blocks_to_copy)
  828. ignored_seq_groups = prefills.ignored_seq_groups
  829. ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
  830. return SchedulerOutputs(
  831. scheduled_seq_groups=scheduled_seq_groups,
  832. num_prefill_groups=num_prefill_groups,
  833. num_batched_tokens=budget.num_batched_tokens,
  834. blocks_to_swap_in=swapped_in.blocks_to_swap_in,
  835. blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
  836. blocks_to_copy=blocks_to_copy,
  837. ignored_seq_groups=ignored_seq_groups,
  838. num_lookahead_slots=running_scheduled.num_lookahead_slots,
  839. running_queue_size=len(self.running),
  840. preempted=preempted,
  841. )
  842. def _schedule_chunked_prefill(self) -> SchedulerOutputs:
  843. """Schedule queued requests.
  844. Chunked prefill allows prefill requests to be chunked and batched together
  845. with decode requests. This policy 1. schedules as many decode requests
  846. as possible, 2. schedules chunked prefill requests that are not yet
  847. finished, 3. schedules swapped requests, and 4. schedules new prefill
  848. requests.
  849. The policy sustains high GPU utilization because it can put prefill
  850. and decode requests in the same batch, while it improves inter-token
  851. latency because decode requests do not need to be blocked
  852. by prefill requests.
  853. """
  854. budget = SchedulingBudget(
  855. token_budget=self.scheduler_config.max_num_batched_tokens,
  856. max_num_seqs=self.scheduler_config.max_num_seqs,
  857. )
  858. curr_loras: Set[int] = set()
  859. prefills = SchedulerPrefillOutputs.create_empty()
  860. swapped_in = SchedulerSwappedInOutputs.create_empty()
  861. # Decodes should always be scheduled first, in FCFS order.
  862. running_scheduled = self._schedule_running(budget,
  863. curr_loras,
  864. enable_chunking=True)
  865. # Schedule swapped out requests.
  866. # If preemption happens, it means we don't have space for swap-in.
  867. if len(running_scheduled.preempted) + len(
  868. running_scheduled.swapped_out) == 0:
  869. swapped_in = self._schedule_swapped(budget, curr_loras)
  870. # Schedule new prefills.
  871. prefills = self._schedule_prefills(budget,
  872. curr_loras,
  873. enable_chunking=True)
  874. assert (budget.num_batched_tokens <=
  875. self.scheduler_config.max_num_batched_tokens)
  876. assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
  877. # Update waiting requests.
  878. self.waiting.extendleft(running_scheduled.preempted)
  879. # Update new running requests.
  880. self.running.extend([s.seq_group for s in prefills.seq_groups])
  881. self.running.extend(
  882. [s.seq_group for s in running_scheduled.decode_seq_groups])
  883. self.running.extend(
  884. [s.seq_group for s in running_scheduled.prefill_seq_groups])
  885. self.running.extend(
  886. [s.seq_group for s in swapped_in.decode_seq_groups])
  887. self.running.extend(
  888. [s.seq_group for s in swapped_in.prefill_seq_groups])
  889. # Update swapped requests.
  890. self.swapped.extend(running_scheduled.swapped_out)
  891. return SchedulerOutputs(
  892. scheduled_seq_groups=(prefills.seq_groups +
  893. running_scheduled.prefill_seq_groups +
  894. swapped_in.prefill_seq_groups +
  895. running_scheduled.decode_seq_groups +
  896. swapped_in.decode_seq_groups),
  897. num_prefill_groups=(len(prefills.seq_groups) +
  898. len(swapped_in.prefill_seq_groups) +
  899. len(running_scheduled.prefill_seq_groups)),
  900. num_batched_tokens=budget.num_batched_tokens,
  901. blocks_to_swap_in=swapped_in.blocks_to_swap_in,
  902. blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
  903. blocks_to_copy=running_scheduled.blocks_to_copy +
  904. swapped_in.blocks_to_copy,
  905. ignored_seq_groups=prefills.ignored_seq_groups +
  906. swapped_in.infeasible_seq_groups,
  907. num_lookahead_slots=running_scheduled.num_lookahead_slots,
  908. running_queue_size=len(self.running),
  909. preempted=(len(running_scheduled.preempted) +
  910. len(running_scheduled.swapped_out)),
  911. )
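# Illustrative budget walk-through under assumed values: decodes are charged
# first, then swapped-in requests, then new/chunked prefills take whatever
# token budget remains. All numbers here are hypothetical.
#
#   token_budget = 768, running queue holds 256 decode sequences
#   1. _schedule_running charges 256 tokens (1 per decode)
#   2. _schedule_swapped charges nothing (swap queue empty)
#   3. _schedule_prefills can chunk a 2000-token prompt down to the
#      remaining 512 tokens; the rest is scheduled in later steps.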
  912. def _schedule(self) -> SchedulerOutputs:
  913. """Schedule queued requests."""
  914. if self.scheduler_config.chunked_prefill_enabled:
  915. return self._schedule_chunked_prefill()
  916. else:
  917. return self._schedule_default()
  918. def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
  919. """Determine whether or not we have enough space in the KV cache to
  920. continue generation of the sequence group.
  921. """
  922. # This is True only in test cases, to trigger artificial preemption.
  923. if (self.enable_artificial_preemption
  924. and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
  925. and self.artificial_preempt_cnt > 0):
  926. self.artificial_preempt_cnt -= 1
  927. return False
  928. # Appending slots only occurs in decoding.
  929. is_prefill = False
  930. return self.block_manager.can_append_slots(
  931. seq_group=seq_group,
  932. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  933. )
  934. def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
  935. no_beam_search = (seq_group.sampling_params.best_of == 1
  936. and not seq_group.sampling_params.use_beam_search)
  937. return no_beam_search
  938. def schedule(
  939. self
  940. ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
  941. # Schedule sequence groups.
  942. # This function call changes the internal states of the scheduler
  943. # such as self.running, self.swapped, and self.waiting.
  944. scheduler_outputs = self._schedule()
  945. now = time.time()
  946. if not self.cache_config.enable_prefix_caching:
  947. common_computed_block_nums = []
  948. allow_async_output_proc: bool = self.use_async_output_proc
  949. # Create input data structures.
  950. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  951. for i, scheduled_seq_group in enumerate(
  952. scheduler_outputs.scheduled_seq_groups):
  953. seq_group = scheduled_seq_group.seq_group
  954. token_chunk_size = scheduled_seq_group.token_chunk_size
  955. seq_group.maybe_set_first_scheduled_time(now)
  956. seq_group_metadata = self._seq_group_metadata_cache[
  957. self.cache_id].get_object()
  958. seq_group_metadata.seq_data.clear()
  959. seq_group_metadata.block_tables.clear()
  960. # seq_id -> SequenceData
  961. seq_data: Dict[int, SequenceData] = {}
  962. # seq_id -> physical block numbers
  963. block_tables: Dict[int, List[int]] = {}
  964. if seq_group.is_encoder_decoder():
  965. # Encoder associated with SequenceGroup
  966. encoder_seq_data = seq_group.get_encoder_seq().data
  967. # Block table for cross-attention
  968. # Also managed at SequenceGroup level
  969. cross_block_table = self.block_manager.get_cross_block_table(
  970. seq_group)
  971. else:
  972. encoder_seq_data = None
  973. cross_block_table = None
  974. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  975. seq_id = seq.seq_id
  976. seq_data[seq_id] = seq.data
  977. block_tables[seq_id] = self.block_manager.get_block_table(seq)
  978. self.block_manager.access_all_blocks_in_seq(seq, now)
  979. if self.cache_config.enable_prefix_caching:
  980. common_computed_block_nums = (
  981. self.block_manager.get_common_computed_block_ids(
  982. seq_group.get_seqs(status=SequenceStatus.RUNNING)))
  983. do_sample = True
  984. is_prompt = seq_group.is_prefill()
  985. # We should send the metadata to workers when the first prefill
  986. # is sent. Subsequent requests could be chunked prefill or decode.
  987. is_first_prefill = False
  988. if is_prompt:
  989. seqs = seq_group.get_seqs()
  990. # Prefill has only 1 sequence.
  991. assert len(seqs) == 1
  992. num_computed_tokens = seqs[0].data.get_num_computed_tokens()
  993. is_first_prefill = num_computed_tokens == 0
  994. # If not all prompt tokens will have been computed after this
  995. # iteration, the prefill is chunked, and we don't need sampling.
  996. # NOTE: We use get_len instead of get_prompt_len because when
  997. # a sequence is preempted, the prefill includes previously generated
  998. # output tokens.
  999. if (token_chunk_size + num_computed_tokens <
  1000. seqs[0].data.get_len()):
  1001. do_sample = False
  1002. # This assumes scheduled_seq_groups is ordered with
  1003. # prefills before decodes.
  1004. if is_first_prefill or not self.scheduler_config.send_delta_data:
  1005. seq_group_metadata = SequenceGroupMetadata(
  1006. request_id=seq_group.request_id,
  1007. is_prompt=is_prompt,
  1008. seq_data=seq_data,
  1009. sampling_params=seq_group.sampling_params,
  1010. block_tables=block_tables,
  1011. do_sample=do_sample,
  1012. pooling_params=seq_group.pooling_params,
  1013. token_chunk_size=token_chunk_size,
  1014. lora_request=seq_group.lora_request,
  1015. computed_block_nums=common_computed_block_nums,
  1016. encoder_seq_data=encoder_seq_data,
  1017. cross_block_table=cross_block_table,
  1018. state=seq_group.state,
  1019. # `multi_modal_data` will only be present for the 1st comm
  1020. # between engine and worker.
  1021. # the subsequent comms can still use delta, but
  1022. # `multi_modal_data` will be None.
  1023. multi_modal_data=seq_group.multi_modal_data
  1024. if scheduler_outputs.num_prefill_groups > 0 else None,
  1025. prompt_adapter_request=seq_group.prompt_adapter_request,
  1026. )
  1027. else:
  1028. # When SPMD mode is enabled, we only send delta data except for
  1029. # the first request to reduce serialization cost.
  1030. seq_data_delta = {}
  1031. for id, data in seq_data.items():
  1032. seq_data_delta[id] = data.get_delta_and_reset()
  1033. seq_group_metadata = SequenceGroupMetadataDelta(
  1034. seq_data_delta,
  1035. seq_group.request_id,
  1036. block_tables,
  1037. is_prompt,
  1038. do_sample=do_sample,
  1039. token_chunk_size=token_chunk_size,
  1040. computed_block_nums=common_computed_block_nums,
  1041. )
  1042. seq_group_metadata_list.append(seq_group_metadata)
  1043. if allow_async_output_proc:
  1044. allow_async_output_proc = self._allow_async_output_proc(
  1045. seq_group)
  1046. # Now that the batch has been created, we can assume all blocks in the
  1047. # batch will have been computed before the next scheduling invocation.
  1048. # This is because the engine assumes that a failure in model execution
  1049. # will crash the Aphrodite instance / will not retry.
  1050. for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
  1051. self.block_manager.mark_blocks_as_computed(
  1052. scheduled_seq_group.seq_group,
  1053. scheduled_seq_group.token_chunk_size)
  1054. self._seq_group_metadata_cache[self.next_cache_id].reset()
  1055. # Move to next cache (if exists)
  1056. self.cache_id = self.next_cache_id
  1057. # Return results
  1058. return (seq_group_metadata_list, scheduler_outputs,
  1059. allow_async_output_proc)
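# Minimal sketch of how an engine step might drive schedule(); the
# `model_executor` object and its execute_model() call are hypothetical and
# not defined in this file.
#
#   >>> metadata_list, scheduler_outputs, allow_async = scheduler.schedule()
#   >>> if not scheduler_outputs.is_empty():
#   ...     outputs = model_executor.execute_model(metadata_list)
#   >>> scheduler.free_finished_seq_groups()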
  1060. def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  1061. self.block_manager.fork(parent_seq, child_seq)
  1062. def free_seq(self, seq: Sequence) -> None:
  1063. """Free a sequence from a block table."""
  1064. self.block_manager.free(seq)
  1065. def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
  1066. """Free finished seqs in a sequence group."""
  1067. for seq in seq_group.get_seqs():
  1068. if seq.is_finished():
  1069. self.free_seq(seq)
  1070. def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
  1071. if seq_group.is_finished():
  1072. # Free cross-attention block table, if it exists
  1073. self._free_seq_group_cross_attn_blocks(seq_group)
  1074. # Add the finished requests to the finished requests list.
  1075. # This list will be used to update the Mamba cache in the
  1076. # next step.
  1077. self._finished_requests_ids.append(seq_group.request_id)
  1078. # Free finished seqs
  1079. self._free_finished_seqs(seq_group)
  1080. def free_finished_seq_groups(self) -> None:
  1081. remaining: Deque[SequenceGroup] = deque()
  1082. for seq_group in self.running:
  1083. self._free_finished_seq_group(seq_group)
  1084. if not seq_group.is_finished():
  1085. remaining.append(seq_group)
  1086. self.running = remaining
  1087. # Handle async stopped sequence groups
  1088. # (ones that reached max model len)
  1089. if self._async_stopped:
  1090. for seq_group in self._async_stopped:
  1091. self._free_seq_group_cross_attn_blocks(seq_group)
  1092. self._finished_requests_ids.append(seq_group.request_id)
  1093. # Free finished seqs
  1094. self._free_finished_seqs(seq_group)
  1095. self._async_stopped.clear()
  1096. def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
  1097. self.block_manager.allocate(seq_group)
  1098. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
  1099. seq.status = SequenceStatus.RUNNING
  1100. def _append_slots(
  1101. self,
  1102. seq_group: SequenceGroup,
  1103. blocks_to_copy: List[Tuple[int, int]],
  1104. ) -> None:
  1105. """Appends new slots to the sequences in the given sequence group.
  1106. Args:
  1107. seq_group (SequenceGroup): The sequence group containing the
  1108. sequences to append slots to.
  1109. blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
  1110. ints, the first int is the source block index, and the second
  1111. int is the destination block index. This list is updated with
  1112. the new source and destination block indices for the appended
  1113. slots.
  1114. """
  1115. num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
  1116. seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
  1117. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  1118. cows = self.block_manager.append_slots(seq, num_lookahead_slots)
  1119. if len(cows) > 0:
  1120. blocks_to_copy.extend(cows)
  1121. def _preempt(
  1122. self,
  1123. seq_group: SequenceGroup,
  1124. blocks_to_swap_out: List[Tuple[int, int]],
  1125. preemption_mode: Optional[PreemptionMode] = None,
  1126. ) -> PreemptionMode:
  1127. # If preemption mode is not specified, we determine the mode as follows:
  1128. # We use recomputation by default since it incurs lower overhead than
  1129. # swapping. However, when the sequence group has multiple sequences
  1130. # (e.g., beam search), recomputation is not currently supported. In
  1131. # such a case, we use swapping instead.
  1132. # FIXME: This makes our scheduling policy a bit bizarre.
  1133. # As swapped sequences are prioritized over waiting sequences,
  1134. # sequence groups with multiple sequences are implicitly prioritized
  1135. # over sequence groups with a single sequence.
  1136. # TODO: Support recomputation for sequence groups with multiple
  1137. # sequences. This may require a more sophisticated CUDA kernel.
  1138. if self.user_specified_preemption_mode is None:
  1139. if seq_group.get_max_num_running_seqs() == 1:
  1140. preemption_mode = PreemptionMode.RECOMPUTE
  1141. else:
  1142. preemption_mode = PreemptionMode.SWAP
  1143. elif self.user_specified_preemption_mode == "swap":
  1144. preemption_mode = PreemptionMode.SWAP
  1145. else:
  1146. preemption_mode = PreemptionMode.RECOMPUTE
  1147. if self.num_cumulative_preemption % 50 == 0:
  1148. logger.warning(
  1149. f"Sequence group {seq_group.request_id} is preempted by "
  1150. f"{preemption_mode} mode because there is "
  1151. "not enough KV cache space. This can affect the end-to-end "
  1152. "performance. Increase gpu_memory_utilization or "
  1153. "tensor_parallel_size to provide more KV cache memory. "
  1154. "total_num_cumulative_preemption="
  1155. f"{self.num_cumulative_preemption + 1}")
  1156. self.num_cumulative_preemption += 1
  1157. if preemption_mode == PreemptionMode.RECOMPUTE:
  1158. self._preempt_by_recompute(seq_group)
  1159. elif preemption_mode == PreemptionMode.SWAP:
  1160. self._preempt_by_swap(seq_group, blocks_to_swap_out)
  1161. else:
  1162. raise AssertionError("Invalid preemption mode.")
  1163. return preemption_mode
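# Sketch of how the mode is chosen when no preemption_mode is configured:
# a single-sequence group is recomputed, while a multi-sequence group
# (e.g. beam search) is swapped out instead.
#
#   seq_group.get_max_num_running_seqs() == 1    -> PreemptionMode.RECOMPUTE
#   seq_group.get_max_num_running_seqs() >  1    -> PreemptionMode.SWAP
#   user_specified_preemption_mode == "swap"     -> always PreemptionMode.SWAP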
  1164. def _preempt_by_recompute(
  1165. self,
  1166. seq_group: SequenceGroup,
  1167. ) -> None:
  1168. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  1169. assert len(seqs) == 1
  1170. for seq in seqs:
  1171. seq.status = SequenceStatus.WAITING
  1172. self.free_seq(seq)
  1173. seq.reset_state_for_recompute()
  1174. def _preempt_by_swap(
  1175. self,
  1176. seq_group: SequenceGroup,
  1177. blocks_to_swap_out: List[Tuple[int, int]],
  1178. ) -> None:
  1179. self._swap_out(seq_group, blocks_to_swap_out)
  1180. def _swap_in(
  1181. self,
  1182. seq_group: SequenceGroup,
  1183. blocks_to_swap_in: List[Tuple[int, int]],
  1184. ) -> None:
  1185. mapping = self.block_manager.swap_in(seq_group)
  1186. blocks_to_swap_in.extend(mapping)
  1187. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  1188. seq.status = SequenceStatus.RUNNING
  1189. def _swap_out(
  1190. self,
  1191. seq_group: SequenceGroup,
  1192. blocks_to_swap_out: List[Tuple[int, int]],
  1193. ) -> None:
  1194. if not self.block_manager.can_swap_out(seq_group):
  1195. # FIXME: Abort the sequence group instead of aborting the
  1196. # entire engine.
  1197. raise RuntimeError(
  1198. "Aborted due to the lack of CPU swap space. Please increase "
  1199. "the swap space to avoid this error.")
  1200. mapping = self.block_manager.swap_out(seq_group)
  1201. blocks_to_swap_out.extend(mapping)
  1202. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  1203. seq.status = SequenceStatus.SWAPPED
  1204. def _passed_delay(self, now: float) -> bool:
  1205. if self.prev_prompt:
  1206. self.last_prompt_latency = now - self.prev_time
  1207. self.prev_time, self.prev_prompt = now, False
  1208. # Delay scheduling prompts to let waiting queue fill up
  1209. if self.scheduler_config.delay_factor > 0 and self.waiting:
  1210. earliest_arrival_time = min(
  1211. [e.metrics.arrival_time for e in self.waiting])
  1212. passed_delay = (
  1213. (now - earliest_arrival_time) >
  1214. (self.scheduler_config.delay_factor * self.last_prompt_latency)
  1215. or not self.running)
  1216. else:
  1217. passed_delay = True
  1218. return passed_delay
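# Worked example with assumed numbers: with delay_factor=0.5 and a last
# prompt step that took 2.0 s, new prompts wait until the oldest waiting
# request has aged more than 0.5 * 2.0 = 1.0 s (unless the running queue is
# empty, in which case prompts are scheduled immediately).
#
#   oldest waiting request arrived 0.4 s ago -> 0.4 > 1.0 is False -> wait
#   oldest waiting request arrived 1.3 s ago -> 1.3 > 1.0 is True  -> schedule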
  1219. def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
  1220. """The number of slots to allocate per sequence per step, beyond known
  1221. token ids. Speculative decoding uses these slots to store KV activations
  1222. of tokens which may or may not be accepted.
  1223. Speculative decoding does not yet support prefill, so we do not perform
  1224. lookahead allocation for prefill.
  1225. """
  1226. if is_prefill:
  1227. return 0
  1228. return self.scheduler_config.num_lookahead_slots
  1229. def _get_num_new_tokens(self, seq_group: SequenceGroup,
  1230. status: SequenceStatus, enable_chunking: bool,
  1231. budget: SchedulingBudget) -> int:
  1232. """Get the next new tokens to compute for a given sequence group
  1233. that's in a given `status`.
  1234. The API could chunk the number of tokens to compute based on `budget`
  1235. if `enable_chunking` is True. If a sequence group has multiple
  1236. sequences (e.g., running beam search), it means it is in decoding
  1237. phase, so chunking doesn't happen.
  1238. Returns 0 if the new token cannot be computed due to token budget.
  1239. """
  1240. num_new_tokens = 0
  1241. seqs = seq_group.get_seqs(status=status)
  1242. for seq in seqs:
  1243. num_new_tokens += seq.get_num_new_tokens()
  1244. assert num_new_tokens > 0
  1245. # Chunk if a running request cannot fit in the given budget.
  1246. # If number of seq > 1, it means it is doing beam search
  1247. # in a decode phase. Do not chunk.
  1248. if enable_chunking and len(seqs) == 1:
  1249. remaining_token_budget = budget.remaining_token_budget()
  1250. if self.cache_config.enable_prefix_caching:
  1251. # When prefix caching is enabled, we always allocate
  1252. # the number of new tokens that is divisible by the block size
  1253. # to avoid partial block matching.
  1254. block_size = self.cache_config.block_size
  1255. remainder = budget.token_budget % block_size
  1256. if remainder != 0:
  1257. raise ValueError("When enabling chunked prefill and "
  1258. "prefix caching, max_num_batched_tokens "
  1259. "(chunk size) must be divisible by "
  1260. "block size, but got chunk_size "
  1261. f"({budget.token_budget}) % block_size "
  1262. f"({block_size}) = {remainder}")
  1263. if remaining_token_budget < num_new_tokens:
  1264. num_new_tokens = (remaining_token_budget //
  1265. block_size) * block_size
  1266. else:
  1267. num_new_tokens = min(num_new_tokens, remaining_token_budget)
  1268. return num_new_tokens
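# Worked example with assumed numbers: when chunking with prefix caching
# enabled, the chunk is rounded down to a multiple of block_size so that a
# partial block never matches the prefix cache.
#
#   block_size=16, remaining_token_budget=100, num_new_tokens=300
#   prefix caching enabled:  (100 // 16) * 16 = 96 tokens this step
#   prefix caching disabled: min(300, 100)    = 100 tokens this step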