scheduler.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372
  1. import enum
  2. import os
  3. import random
  4. import time
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
  8. from loguru import logger
  9. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  10. from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
  11. SequenceGroupMetadata,
  12. SequenceGroupMetadataDelta,
  13. SequenceStatus)
  14. from aphrodite.common.utils import Device, PyObjectCache
  15. from aphrodite.lora.request import LoRARequest
  16. from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
  17. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  18. # Test-only. If configured, decode is preempted with
  19. # ARTIFICIAL_PREEMPTION_PROB% probability.
  20. ENABLE_ARTIFICIAL_PREEMPT = bool(
  21. os.getenv("APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
  22. ARTIFICIAL_PREEMPTION_PROB = 0.5
  23. ARTIFICIAL_PREEMPTION_MAX_CNT = 500
  24. class PreemptionMode(enum.Enum):
  25. """Preemption modes.
  26. 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
  27. and swap them back in when the sequences are resumed.
  28. 2. Recomputation: Discard the blocks of the preempted sequences and
  29. recompute them when the sequences are resumed, treating the sequences as
  30. new prompts.
  31. """
  32. SWAP = enum.auto()
  33. RECOMPUTE = enum.auto()
  34. @dataclass
  35. class SchedulingBudget:
  36. """The available slots for scheduling.
  37. TODO: Right now, the budget is request_id-aware meaning it can ignore
  38. budget update from the same request_id. It is because in normal scheduling
  39. path, we update RUNNING num_seqs ahead of time, meaning it could be
  40. updated more than once when scheduling RUNNING requests. Since this won't
  41. happen if we only have chunked prefill scheduling, we can remove this
  42. feature from the API when chunked prefill is enabled by default.
  43. """
  44. token_budget: int
  45. max_num_seqs: int
  46. _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
  47. _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
  48. _num_batched_tokens: int = 0
  49. _num_curr_seqs: int = 0
  50. def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
  51. assert num_new_tokens != 0
  52. assert num_new_seqs != 0
  53. return (self.num_batched_tokens + num_new_tokens <= self.token_budget
  54. and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
  55. def remaining_token_budget(self):
  56. return self.token_budget - self.num_batched_tokens
  57. def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
  58. if req_id in self._request_ids_num_batched_tokens:
  59. return
  60. self._request_ids_num_batched_tokens.add(req_id)
  61. self._num_batched_tokens += num_batched_tokens
  62. def subtract_num_batched_tokens(self, req_id: str,
  63. num_batched_tokens: int):
  64. if req_id in self._request_ids_num_batched_tokens:
  65. self._request_ids_num_batched_tokens.remove(req_id)
  66. self._num_batched_tokens -= num_batched_tokens
  67. def add_num_seqs(self, req_id: str, num_curr_seqs: int):
  68. if req_id in self._request_ids_num_curr_seqs:
  69. return
  70. self._request_ids_num_curr_seqs.add(req_id)
  71. self._num_curr_seqs += num_curr_seqs
  72. def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
  73. if req_id in self._request_ids_num_curr_seqs:
  74. self._request_ids_num_curr_seqs.remove(req_id)
  75. self._num_curr_seqs -= num_curr_seqs
  76. @property
  77. def num_batched_tokens(self):
  78. return self._num_batched_tokens
  79. @property
  80. def num_curr_seqs(self):
  81. return self._num_curr_seqs
  82. @dataclass
  83. class ScheduledSequenceGroup:
  84. # A sequence group that's scheduled.
  85. seq_group: SequenceGroup
  86. # The total chunk size (number of tokens) to process for next iteration.
  87. # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
  88. # chunked, it can be smaller than that.
  89. token_chunk_size: int
  90. @dataclass
  91. class SchedulerOutputs:
  92. """The scheduling decision made from a scheduler."""
  93. # Scheduled sequence groups.
  94. scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
  95. # Number of prefill groups scheduled.
  96. num_prefill_groups: int
  97. # Total number of batched tokens.
  98. num_batched_tokens: int
  99. # Blocks to swap in. List of CPU -> GPU block number.
  100. blocks_to_swap_in: List[Tuple[int, int]]
  101. # Blocks to swap out. List of GPU -> CPU block number.
  102. blocks_to_swap_out: List[Tuple[int, int]]
  103. # Blocks to copy. Source to dest block.
  104. blocks_to_copy: List[Tuple[int, int]]
  105. # Sequence groups that are going to be ignored.
  106. ignored_seq_groups: List[SequenceGroup]
  107. # The number of slots for lookahead decoding.
  108. num_lookahead_slots: int
  109. # The number of requests in the running queue
  110. running_queue_size: int
  111. preempted: int
  112. def __post_init__(self):
  113. # Swap in and swap out should never happen at the same time.
  114. assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
  115. self.num_loras: int = len(self.lora_requests)
  116. if self.num_loras > 0:
  117. self._sort_by_lora_ids()
  118. self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
  119. def is_empty(self) -> bool:
  120. # NOTE: We do not consider the ignored sequence groups.
  121. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
  122. and not self.blocks_to_swap_out and not self.blocks_to_copy)
  123. def _sort_by_lora_ids(self):
  124. assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
  125. def key_fn(group: ScheduledSequenceGroup):
  126. key = (group.seq_group.lora_int_id, group.seq_group.request_id)
  127. if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
  128. # Sort sequence groups so that all prefills come before all
  129. # decodes as required by chunked prefill.
  130. return (not group.seq_group.is_prefill(), *key)
  131. return key
  132. self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
  133. key=key_fn)
  134. @property
  135. def lora_requests(self) -> Set[LoRARequest]:
  136. return {
  137. g.seq_group.lora_request
  138. for g in self.scheduled_seq_groups
  139. if g.seq_group.lora_request is not None
  140. }
  141. @property
  142. def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
  143. return {
  144. g.seq_group.prompt_adapter_request
  145. for g in self.scheduled_seq_groups
  146. if g.seq_group.prompt_adapter_request is not None
  147. }
  148. @dataclass
  149. class SchedulerRunningOutputs:
  150. """The requests that are scheduled from a running queue.
  151. Could contain prefill (prefill that's chunked) or decodes. If there's not
  152. enough memory, it can be preempted (for recompute) or swapped out.
  153. """
  154. # Selected sequences that are running and in a decoding phase.
  155. decode_seq_groups: List[ScheduledSequenceGroup]
  156. # Selected sequences that are running and in a prefill phase.
  157. # I.e., it means the prefill has been chunked.
  158. prefill_seq_groups: List[ScheduledSequenceGroup]
  159. # The preempted sequences.
  160. preempted: List[SequenceGroup]
  161. # Sequences that are swapped out.
  162. swapped_out: List[SequenceGroup]
  163. # The blocks to swap out.
  164. blocks_to_swap_out: List[Tuple[int, int]]
  165. # The blocks to copy.
  166. blocks_to_copy: List[Tuple[int, int]]
  167. # The number of slots for lookahead decoding.
  168. num_lookahead_slots: int
  169. # Optimization for fast-access to seq_group lists
  170. decode_seq_groups_list: List[SequenceGroup]
  171. prefill_seq_groups_list: List[SequenceGroup]
  172. @classmethod
  173. def create_empty(cls) -> "SchedulerRunningOutputs":
  174. return SchedulerRunningOutputs(
  175. decode_seq_groups=[],
  176. prefill_seq_groups=[],
  177. preempted=[],
  178. swapped_out=[],
  179. blocks_to_swap_out=[],
  180. blocks_to_copy=[],
  181. num_lookahead_slots=0,
  182. decode_seq_groups_list=[],
  183. prefill_seq_groups_list=[],
  184. )
  185. @dataclass
  186. class SchedulerSwappedInOutputs:
  187. """The requests that are scheduled from a swap queue.
  188. Could contain prefill (prefill that's chunked) or decodes.
  189. """
  190. # Selected sequences that are going to be swapped in and is in a
  191. # decoding phase.
  192. decode_seq_groups: List[SequenceGroup]
  193. # Selected sequences that are going to be swapped in and in a prefill
  194. # phase. I.e., it means the prefill has been chunked.
  195. prefill_seq_groups: List[SequenceGroup]
  196. # The blocks to swap in.
  197. blocks_to_swap_in: List[Tuple[int, int]]
  198. # The blocks to copy.
  199. blocks_to_copy: List[Tuple[int, int]]
  200. # The number of slots for lookahead decoding.
  201. num_lookahead_slots: int
  202. # Infeasible sequence groups.
  203. infeasible_seq_groups: List[SequenceGroup]
  204. @classmethod
  205. def create_empty(cls) -> "SchedulerSwappedInOutputs":
  206. return SchedulerSwappedInOutputs(
  207. decode_seq_groups=[],
  208. prefill_seq_groups=[],
  209. blocks_to_swap_in=[],
  210. blocks_to_copy=[],
  211. num_lookahead_slots=0,
  212. infeasible_seq_groups=[],
  213. )
  214. @dataclass
  215. class SchedulerPrefillOutputs:
  216. """The requests that are scheduled from a waiting queue.
  217. Could contain a fresh prefill requests or preempted requests that need
  218. to be recomputed from scratch.
  219. """
  220. # Selected sequences for prefill.
  221. seq_groups: List[SequenceGroup]
  222. # Ignored sequence groups.
  223. ignored_seq_groups: List[SequenceGroup]
  224. num_lookahead_slots: int
  225. @classmethod
  226. def create_empty(cls) -> "SchedulerPrefillOutputs":
  227. return SchedulerPrefillOutputs(
  228. seq_groups=[],
  229. ignored_seq_groups=[],
  230. num_lookahead_slots=0,
  231. )
  232. def seq_group_metadata_builder():
  233. return SequenceGroupMetadata(request_id="",
  234. is_prompt=False,
  235. seq_data={},
  236. sampling_params=None,
  237. block_tables={})
  238. def scheduler_running_outputs_builder():
  239. return SchedulerRunningOutputs(decode_seq_groups=[],
  240. prefill_seq_groups=[],
  241. preempted=[],
  242. swapped_out=[],
  243. blocks_to_swap_out=[],
  244. blocks_to_copy=[],
  245. num_lookahead_slots=0,
  246. prefill_seq_groups_list=[],
  247. decode_seq_groups_list=[])
  248. def scheduled_seq_group_builder():
  249. return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
  250. class Scheduler:
  251. def __init__(
  252. self,
  253. scheduler_config: SchedulerConfig,
  254. cache_config: CacheConfig,
  255. lora_config: Optional[LoRAConfig],
  256. pipeline_parallel_size: int = 1,
  257. ) -> None:
  258. self.scheduler_config = scheduler_config
  259. self.cache_config = cache_config
  260. # Note for LoRA scheduling: the current policy is extremely
  261. # simple and NOT fair. It can lead to starvation of some
  262. # LoRAs. This should be improved in the future.
  263. self.lora_config = lora_config
  264. version = "v1"
  265. if self.scheduler_config.use_v2_block_manager:
  266. version = "v2"
  267. if (self.scheduler_config.embedding_mode
  268. or self.scheduler_config.is_attention_free):
  269. version = "placeholder"
  270. BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
  271. version)
  272. num_gpu_blocks = cache_config.num_gpu_blocks
  273. if num_gpu_blocks:
  274. num_gpu_blocks //= pipeline_parallel_size
  275. num_cpu_blocks = cache_config.num_cpu_blocks
  276. if num_cpu_blocks:
  277. num_cpu_blocks //= pipeline_parallel_size
  278. # Create the block space manager.
  279. self.block_manager = BlockSpaceManagerImpl(
  280. block_size=self.cache_config.block_size,
  281. num_gpu_blocks=num_gpu_blocks,
  282. num_cpu_blocks=num_cpu_blocks,
  283. sliding_window=self.cache_config.sliding_window,
  284. enable_caching=self.cache_config.enable_prefix_caching)
  285. # Sequence groups in the WAITING state.
  286. # Contain new prefill or preempted requests.
  287. self.waiting: Deque[SequenceGroup] = deque()
  288. # Sequence groups in the RUNNING state.
  289. # Contain decode requests.
  290. self.running: Deque[SequenceGroup] = deque()
  291. # Sequence groups in the SWAPPED state.
  292. # Contain decode requests that are swapped out.
  293. self.swapped: Deque[SequenceGroup] = deque()
  294. # Sequence groups finished requests ids since last step iteration.
  295. # It lets the model know that any state associated with these requests
  296. # can and must be released after the current step.
  297. # This is used to evict the finished requests from the Mamba cache.
  298. self._finished_requests_ids: List[str] = list()
  299. # Time at previous scheduling step
  300. self.prev_time = 0.0
  301. # Did we schedule a prompt at previous step?
  302. self.prev_prompt = False
  303. # Latency of the last prompt step
  304. self.last_prompt_latency = 0.0
  305. # preemption mode, RECOMPUTE or SWAP
  306. self.user_specified_preemption_mode = scheduler_config.preemption_mode
  307. # The following field is test-only. It is used to inject artificial
  308. # preemption.
  309. self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
  310. self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
  311. if self.enable_artificial_preemption
  312. else 0)
  313. self.num_cumulative_preemption: int = 0
  314. # Used to cache python objects
  315. self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
  316. scheduler_running_outputs_builder)
  317. self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
  318. scheduled_seq_group_builder)
  319. @property
  320. def lora_enabled(self) -> bool:
  321. return bool(self.lora_config)
  322. @property
  323. def num_decoding_tokens_per_seq(self) -> int:
  324. """The number of new tokens."""
  325. return 1
  326. def add_seq_group(self, seq_group: SequenceGroup) -> None:
  327. # Add sequence groups to the waiting queue.
  328. self.waiting.append(seq_group)
  329. def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
  330. # Add sequence groups to the running queue.
  331. # Only for testing purposes.
  332. self.running.append(seq_group)
  333. def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
  334. # Add sequence groups to the swapped queue.
  335. # Only for testing purposes.
  336. self.swapped.append(seq_group)
  337. def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
  338. """Aborts a sequence group with the given ID.
  339. Check if the sequence group with the given ID
  340. is present in any of the state queue.
  341. If present, remove the sequence group from the state queue.
  342. Also, if any of the sequences in the sequence group is not finished,
  343. free the sequence with status `FINISHED_ABORTED`.
  344. Otherwise, do nothing.
  345. Args:
  346. request_id: The ID(s) of the sequence group to abort.
  347. """
  348. if isinstance(request_id, str):
  349. request_id = (request_id, )
  350. request_ids = set(request_id)
  351. for state_queue in [self.waiting, self.running, self.swapped]:
  352. aborted_groups: List[SequenceGroup] = []
  353. for seq_group in state_queue:
  354. if not request_ids:
  355. # Using 'break' here may add two extra iterations,
  356. # but is acceptable to reduce complexity.
  357. break
  358. if seq_group.request_id in request_ids:
  359. # Appending aborted group into pending list.
  360. aborted_groups.append(seq_group)
  361. request_ids.remove(seq_group.request_id)
  362. for aborted_group in aborted_groups:
  363. # Remove the sequence group from the state queue.
  364. state_queue.remove(aborted_group)
  365. # Remove the aborted request from the Mamba cache.
  366. self._finished_requests_ids.append(aborted_group.request_id)
  367. for seq in aborted_group.get_seqs():
  368. if seq.is_finished():
  369. continue
  370. seq.status = SequenceStatus.FINISHED_ABORTED
  371. self.free_seq(seq)
  372. self._free_seq_group_cross_attn_blocks(aborted_group)
  373. def _free_seq_group_cross_attn_blocks(
  374. self,
  375. seq_group: SequenceGroup,
  376. ) -> None:
  377. """
  378. Free a sequence group from a cross-attention block table.
  379. Has no effect on decoder-only models.
  380. """
  381. if seq_group.is_encoder_decoder():
  382. self.block_manager.free_cross(seq_group)
  383. def has_unfinished_seqs(self) -> bool:
  384. return len(self.waiting) != 0 or len(self.running) != 0 or len(
  385. self.swapped) != 0
  386. def get_prefix_cache_hit_rate(self, device: Device) -> float:
  387. return self.block_manager.get_prefix_cache_hit_rate(device)
  388. def get_num_unfinished_seq_groups(self) -> int:
  389. return len(self.waiting) + len(self.running) + len(self.swapped)
  390. def get_and_reset_finished_requests_ids(self) -> List[str]:
  391. """Flushes the list of request ids of previously finished seq_groups."""
  392. finished_requests_ids = self._finished_requests_ids
  393. self._finished_requests_ids = list()
  394. return finished_requests_ids
    def _schedule_running(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerRunningOutputs:
        """Schedule sequence groups that are running.

        Running queue should include decode and chunked prefill requests.
        Groups are taken from the front of ``self.running``, which is
        mutated in place. When a group cannot append its slots, groups are
        popped from the back of the queue and preempted. If no other group
        is left, the group itself is preempted.

        Args:
            budget: The scheduling budget. Updated in place: the current
                group's tokens/seqs are subtracted (keyed by request id)
                while it waits on preemption, and a scheduled group's tokens
                (plus its seqs when chunking is enabled) are added.
            curr_loras: Currently batched lora request ids, or None. Updated
                in place: the current group's LoRA id is removed while it
                waits on preemption and added once it is scheduled.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerRunningOutputs. The object comes from
            ``self._scheduler_running_outputs_cache``, and that cache (plus
            the ScheduledSequenceGroup cache) is reset before returning.
            NOTE(review): this presumably means the same objects are handed
            out again on the next call, so callers must consume the result
            first; confirm against PyObjectCache.
        """
        # Reuse a cached outputs object and clear every list it holds.
        ret: SchedulerRunningOutputs = \
            self._scheduler_running_outputs_cache.get_object()
        ret.blocks_to_swap_out.clear()
        ret.blocks_to_copy.clear()
        ret.decode_seq_groups.clear()
        ret.prefill_seq_groups.clear()
        ret.preempted.clear()
        ret.swapped_out.clear()

        ret.num_lookahead_slots = self._get_num_lookahead_slots(
            is_prefill=False)

        ret.decode_seq_groups_list.clear()
        ret.prefill_seq_groups_list.clear()

        # Blocks that need to be swapped or copied before model execution.
        # These locals alias the lists inside `ret`, so appending to them
        # fills the returned object.
        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy

        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
        prefill_seq_groups: List[
            ScheduledSequenceGroup] = ret.prefill_seq_groups
        preempted: List[SequenceGroup] = ret.preempted
        swapped_out: List[SequenceGroup] = ret.swapped_out

        # NOTE: Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        running_queue = self.running

        while running_queue:
            seq_group = running_queue[0]
            num_running_tokens = self._get_num_new_tokens(
                seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

            # Zero tokens for the head group: stop here and leave it (and
            # everything behind it) in the running queue.
            if num_running_tokens == 0:
                break

            running_queue.popleft()
            # Keep preempting until this group can append its slots, or
            # until it is preempted itself.
            while not self._can_append_slots(seq_group):
                budget.subtract_num_batched_tokens(seq_group.request_id,
                                                   num_running_tokens)
                num_running_seqs = seq_group.get_max_num_running_seqs()
                budget.subtract_num_seqs(seq_group.request_id,
                                         num_running_seqs)

                if (curr_loras is not None and seq_group.lora_int_id > 0
                        and seq_group.lora_int_id in curr_loras):
                    curr_loras.remove(seq_group.lora_int_id)

                if running_queue:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = running_queue.pop()
                    preempted_mode = self._preempt(victim_seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(victim_seq_group)
                    else:
                        swapped_out.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    preempted_mode = self._preempt(seq_group,
                                                   blocks_to_swap_out)
                    if preempted_mode == PreemptionMode.RECOMPUTE:
                        preempted.append(seq_group)
                    else:
                        swapped_out.append(seq_group)
                    # `break` skips the while-else below: this group is not
                    # scheduled.
                    break
            else:
                # The group can append its slots, so schedule it.
                self._append_slots(seq_group, blocks_to_copy)
                is_prefill = seq_group.is_prefill()

                scheduled_seq_group: ScheduledSequenceGroup = \
                    self._scheduled_seq_group_cache.get_object()
                scheduled_seq_group.seq_group = seq_group
                if is_prefill:
                    scheduled_seq_group.token_chunk_size = num_running_tokens
                    prefill_seq_groups.append(scheduled_seq_group)
                    ret.prefill_seq_groups_list.append(seq_group)
                else:
                    # Decodes always process exactly one token per step.
                    scheduled_seq_group.token_chunk_size = 1
                    decode_seq_groups.append(scheduled_seq_group)
                    ret.decode_seq_groups_list.append(seq_group)

                budget.add_num_batched_tokens(seq_group.request_id,
                                              num_running_tokens)
                # OPTIMIZATION: Note that get_max_num_running_seqs is
                # expensive. For the default scheduling case where
                # enable_chunking is False, num_seqs are updated before running
                # this method, so we don't have to update it again here.
                if enable_chunking:
                    num_running_seqs = seq_group.get_max_num_running_seqs()
                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
                if curr_loras is not None and seq_group.lora_int_id > 0:
                    curr_loras.add(seq_group.lora_int_id)

        self._scheduler_running_outputs_cache.reset()
        self._scheduled_seq_group_cache.reset()

        return ret
    def _schedule_swapped(
        self,
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
    ) -> SchedulerSwappedInOutputs:
        """Schedule sequence groups that are swapped out.

        It schedules swapped requests as long as it fits `budget` and
        curr_loras <= max_lora from the scheduling config. The input arguments
        `budget` and `curr_loras` are updated based on scheduled seq_groups.

        Groups are examined from the front of ``self.swapped``, which is
        mutated in place. Scheduling stops at the first group that must wait
        (AllocStatus.LATER, zero new tokens, or budget exceeded). Groups that
        can never fit (AllocStatus.NEVER) are marked FINISHED_IGNORED and
        reported as infeasible. Groups skipped only because no LoRA slot is
        free are put back at the front of the queue in their original order.

        Args:
            budget: The scheduling budget. The argument is in-place updated
                when any requests are swapped in.
            curr_loras: Currently batched lora request ids. The argument is
                in-place updated when any requests are swapped in.
            enable_chunking: If True, seq group can be chunked and only a
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerSwappedInOutputs. Its ``num_lookahead_slots`` is always
            computed with ``is_prefill=False``.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: List[Tuple[int, int]] = []
        blocks_to_copy: List[Tuple[int, int]] = []
        decode_seq_groups: List[ScheduledSequenceGroup] = []
        prefill_seq_groups: List[ScheduledSequenceGroup] = []
        infeasible_seq_groups: List[SequenceGroup] = []

        swapped_queue = self.swapped

        # Groups deferred for lack of a LoRA slot. appendleft here, followed
        # by extendleft after the loop, reverses twice, so the original
        # order is restored.
        leftover_swapped: Deque[SequenceGroup] = deque()
        while swapped_queue:
            seq_group = swapped_queue[0]

            # If the sequence group cannot be swapped in, stop.
            is_prefill = seq_group.is_prefill()
            alloc_status = self.block_manager.can_swap_in(
                seq_group, self._get_num_lookahead_slots(is_prefill))
            if alloc_status == AllocStatus.LATER:
                break
            elif alloc_status == AllocStatus.NEVER:
                logger.warning(f"Failing the request {seq_group.request_id} "
                               "because there's not enough kv cache blocks to "
                               "run the entire sequence.")
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible_seq_groups.append(seq_group)
                swapped_queue.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                assert curr_loras is not None
                assert self.lora_config is not None
                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_swapped.appendleft(seq_group)
                    swapped_queue.popleft()
                    continue

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)

            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
                break

            if lora_int_id > 0 and curr_loras is not None:
                curr_loras.add(lora_int_id)
            swapped_queue.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slots(seq_group, blocks_to_copy)
            # Re-read after swap-in; it selects the output list below.
            is_prefill = seq_group.is_prefill()
            if is_prefill:
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(seq_group,
                                           token_chunk_size=num_new_tokens))
            else:
                # Decodes always process exactly one token per step.
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Put LoRA-deferred groups back at the front, in original order.
        swapped_queue.extendleft(leftover_swapped)

        return SchedulerSwappedInOutputs(
            decode_seq_groups=decode_seq_groups,
            prefill_seq_groups=prefill_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=False),
            infeasible_seq_groups=infeasible_seq_groups,
        )
  595. def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
  596. if self.scheduler_config.chunked_prefill_enabled:
  597. prompt_limit = self.scheduler_config.max_model_len
  598. else:
  599. prompt_limit = min(self.scheduler_config.max_model_len,
  600. self.scheduler_config.max_num_batched_tokens)
  601. # Model is fine tuned with long context. Return the fine tuned max_len.
  602. if (seq_group.lora_request
  603. and seq_group.lora_request.long_lora_max_len):
  604. assert prompt_limit <= seq_group.lora_request.long_lora_max_len
  605. return seq_group.lora_request.long_lora_max_len
  606. else:
  607. return prompt_limit
  608. def _schedule_prefills(
  609. self,
  610. budget: SchedulingBudget,
  611. curr_loras: Optional[Set[int]],
  612. enable_chunking: bool = False,
  613. ) -> SchedulerPrefillOutputs:
  614. """Schedule sequence groups that are in prefill stage.
  615. Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
  616. as a new prefill (that starts from beginning -> most recently generated
  617. tokens).
  618. It schedules waiting requests as long as it fits `budget` and
  619. curr_loras <= max_lora from the scheduling config. The input arguments
  620. `budget` and `curr_loras` are updated based on scheduled seq_groups.
  621. Args:
  622. budget: The scheduling budget. The argument is in-place updated
  623. when any requests are scheduled.
  624. curr_loras: Currently batched lora request ids. The argument is
  625. in-place updated when any requests are scheduled.
  626. enable_chunking: If True, seq group can be chunked and only a
  627. chunked number of tokens are scheduled if
  628. `budget.num_batched_tokens` has not enough capacity to schedule
  629. all tokens.
  630. Returns:
  631. SchedulerPrefillOutputs.
  632. """
  633. ignored_seq_groups: List[SequenceGroup] = []
  634. seq_groups: List[SequenceGroup] = []
  635. waiting_queue = self.waiting
  636. leftover_waiting_sequences: Deque[SequenceGroup] = deque()
  637. while self._passed_delay(time.time()) and waiting_queue:
  638. seq_group = waiting_queue[0]
  639. waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
  640. assert len(waiting_seqs) == 1, (
  641. "Waiting sequence group should have only one prompt "
  642. "sequence.")
  643. num_new_tokens = self._get_num_new_tokens(seq_group,
  644. SequenceStatus.WAITING,
  645. enable_chunking, budget)
  646. if not enable_chunking:
  647. num_prompt_tokens = waiting_seqs[0].get_len()
  648. assert num_new_tokens == num_prompt_tokens
  649. prompt_limit = self._get_prompt_limit(seq_group)
  650. if num_new_tokens > prompt_limit:
  651. logger.warning(
  652. f"Input prompt ({num_new_tokens} tokens) is too long"
  653. f" and exceeds limit of {prompt_limit}")
  654. for seq in waiting_seqs:
  655. seq.status = SequenceStatus.FINISHED_IGNORED
  656. ignored_seq_groups.append(seq_group)
  657. waiting_queue.popleft()
  658. continue
  659. # If the sequence group cannot be allocated, stop.
  660. can_allocate = self.block_manager.can_allocate(seq_group)
  661. if can_allocate == AllocStatus.LATER:
  662. break
  663. elif can_allocate == AllocStatus.NEVER:
  664. logger.warning(
  665. f"Input prompt ({num_new_tokens} tokens) is too long"
  666. " and exceeds the capacity of block_manager")
  667. for seq in waiting_seqs:
  668. seq.status = SequenceStatus.FINISHED_IGNORED
  669. ignored_seq_groups.append(seq_group)
  670. waiting_queue.popleft()
  671. continue
  672. lora_int_id = 0
  673. if self.lora_enabled:
  674. lora_int_id = seq_group.lora_int_id
  675. assert curr_loras is not None
  676. assert self.lora_config is not None
  677. if (self.lora_enabled and lora_int_id > 0
  678. and lora_int_id not in curr_loras
  679. and len(curr_loras) >= self.lora_config.max_loras):
  680. # We don't have a space for another LoRA, so
  681. # we ignore this request for now.
  682. leftover_waiting_sequences.appendleft(seq_group)
  683. waiting_queue.popleft()
  684. continue
  685. num_new_seqs = seq_group.get_max_num_running_seqs()
  686. if (num_new_tokens == 0
  687. or not budget.can_schedule(num_new_tokens=num_new_tokens,
  688. num_new_seqs=num_new_seqs)):
  689. break
  690. # Can schedule this request.
  691. if curr_loras is not None and lora_int_id > 0:
  692. curr_loras.add(lora_int_id)
  693. waiting_queue.popleft()
  694. self._allocate_and_set_running(seq_group)
  695. seq_group.init_multi_step(
  696. num_scheduler_steps=self._get_num_lookahead_slots(
  697. is_prefill=True) + 1)
  698. seq_groups.append(
  699. ScheduledSequenceGroup(seq_group=seq_group,
  700. token_chunk_size=num_new_tokens))
  701. budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
  702. budget.add_num_seqs(seq_group.request_id, num_new_seqs)
  703. # Queue requests that couldn't be scheduled.
  704. waiting_queue.extendleft(leftover_waiting_sequences)
  705. if len(seq_groups) > 0:
  706. self.prev_prompt = True
  707. return SchedulerPrefillOutputs(
  708. seq_groups=seq_groups,
  709. ignored_seq_groups=ignored_seq_groups,
  710. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
  711. def _schedule_default(self) -> SchedulerOutputs:
  712. """Schedule queued requests.
  713. The current policy is designed to optimize the throughput. First,
  714. it batches as many prefill requests as possible. And it schedules
  715. decodes. If there's a pressure on GPU memory, decode requests can
  716. be swapped or preempted.
  717. """
  718. # Include running requests to the budget.
  719. budget = SchedulingBudget(
  720. token_budget=self.scheduler_config.max_num_batched_tokens,
  721. max_num_seqs=self.scheduler_config.max_num_seqs,
  722. )
  723. # Make sure we include num running seqs before scheduling prefill,
  724. # so that we don't schedule beyond max_num_seqs for prefill.
  725. for seq_group in self.running:
  726. budget.add_num_seqs(seq_group.request_id,
  727. seq_group.get_max_num_running_seqs())
  728. curr_loras = set(
  729. seq_group.lora_int_id for seq_group in self.running
  730. if seq_group.lora_int_id > 0) if self.lora_enabled else None
  731. prefills = SchedulerPrefillOutputs.create_empty()
  732. running_scheduled = SchedulerRunningOutputs.create_empty()
  733. swapped_in = SchedulerSwappedInOutputs.create_empty()
  734. # If any requests are swapped, prioritized swapped requests.
  735. if not self.swapped:
  736. prefills = self._schedule_prefills(budget,
  737. curr_loras,
  738. enable_chunking=False)
  739. # Don't schedule decodes if prefills are scheduled.
  740. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
  741. # only contains decode requests, not chunked prefills.
  742. if len(prefills.seq_groups) == 0:
  743. running_scheduled = self._schedule_running(budget,
  744. curr_loras,
  745. enable_chunking=False)
  746. # If any sequence group is preempted, do not swap in any sequence
  747. # group. because it means there's no slot for new running requests.
  748. if len(running_scheduled.preempted) + len(
  749. running_scheduled.swapped_out) == 0:
  750. swapped_in = self._schedule_swapped(budget, curr_loras)
  751. assert (budget.num_batched_tokens <=
  752. self.scheduler_config.max_num_batched_tokens)
  753. assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
  754. # Update waiting requests.
  755. self.waiting.extendleft(running_scheduled.preempted)
  756. # Update new running requests.
  757. if len(prefills.seq_groups) > 0:
  758. self.running.extend([s.seq_group for s in prefills.seq_groups])
  759. self.running.extend(running_scheduled.decode_seq_groups_list)
  760. if len(swapped_in.decode_seq_groups) > 0:
  761. self.running.extend(
  762. [s.seq_group for s in swapped_in.decode_seq_groups])
  763. # Update swapped requests.
  764. self.swapped.extend(running_scheduled.swapped_out)
  765. preempted = (len(running_scheduled.preempted) +
  766. len(running_scheduled.swapped_out))
  767. # There should be no prefill from running queue because this policy
  768. # doesn't allow chunked prefills.
  769. assert len(running_scheduled.prefill_seq_groups) == 0
  770. assert len(swapped_in.prefill_seq_groups) == 0
  771. # Merge lists
  772. num_prefill_groups = len(prefills.seq_groups)
  773. if num_prefill_groups > 0:
  774. scheduled_seq_groups = prefills.seq_groups
  775. scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
  776. else:
  777. scheduled_seq_groups = running_scheduled.decode_seq_groups
  778. scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
  779. blocks_to_copy = running_scheduled.blocks_to_copy
  780. blocks_to_copy.extend(swapped_in.blocks_to_copy)
  781. ignored_seq_groups = prefills.ignored_seq_groups
  782. ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
  783. return SchedulerOutputs(
  784. scheduled_seq_groups=scheduled_seq_groups,
  785. num_prefill_groups=num_prefill_groups,
  786. num_batched_tokens=budget.num_batched_tokens,
  787. blocks_to_swap_in=swapped_in.blocks_to_swap_in,
  788. blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
  789. blocks_to_copy=blocks_to_copy,
  790. ignored_seq_groups=ignored_seq_groups,
  791. num_lookahead_slots=running_scheduled.num_lookahead_slots,
  792. running_queue_size=len(self.running),
  793. preempted=preempted,
  794. )
  795. def _schedule_chunked_prefill(self) -> SchedulerOutputs:
  796. """Schedule queued requests.
  797. Chunked prefill allows to chunk prefill requests, batch them together
  798. with decode requests. This policy 1. schedule as many decoding requests
  799. as possible. 2. schedule chunked prefill requests that are not
  800. finished. 3. schedule swapped request. 4. schedule new prefill
  801. requests.
  802. The policy can sustain the high GPU utilization because it can put
  803. prefill and decodes requests to the same batch, while it improves
  804. inter token latency because decodes requests don't need to be blocked
  805. by prefill requests.
  806. """
  807. budget = SchedulingBudget(
  808. token_budget=self.scheduler_config.max_num_batched_tokens,
  809. max_num_seqs=self.scheduler_config.max_num_seqs,
  810. )
  811. curr_loras: Set[int] = set()
  812. prefills = SchedulerPrefillOutputs.create_empty()
  813. swapped_in = SchedulerSwappedInOutputs.create_empty()
  814. # Decoding should be always scheduled first by fcfs.
  815. running_scheduled = self._schedule_running(budget,
  816. curr_loras,
  817. enable_chunking=True)
  818. # Schedule swapped out requests.
  819. # If preemption happens, it means we don't have space for swap-in.
  820. if len(running_scheduled.preempted) + len(
  821. running_scheduled.swapped_out) == 0:
  822. swapped_in = self._schedule_swapped(budget, curr_loras)
  823. # Schedule new prefills.
  824. prefills = self._schedule_prefills(budget,
  825. curr_loras,
  826. enable_chunking=True)
  827. assert (budget.num_batched_tokens <=
  828. self.scheduler_config.max_num_batched_tokens)
  829. assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
  830. # Update waiting requests.
  831. self.waiting.extendleft(running_scheduled.preempted)
  832. # Update new running requests.
  833. self.running.extend([s.seq_group for s in prefills.seq_groups])
  834. self.running.extend(
  835. [s.seq_group for s in running_scheduled.decode_seq_groups])
  836. self.running.extend(
  837. [s.seq_group for s in running_scheduled.prefill_seq_groups])
  838. self.running.extend(
  839. [s.seq_group for s in swapped_in.decode_seq_groups])
  840. self.running.extend(
  841. [s.seq_group for s in swapped_in.prefill_seq_groups])
  842. # Update swapped requests.
  843. self.swapped.extend(running_scheduled.swapped_out)
  844. return SchedulerOutputs(
  845. scheduled_seq_groups=(prefills.seq_groups +
  846. running_scheduled.prefill_seq_groups +
  847. swapped_in.prefill_seq_groups +
  848. running_scheduled.decode_seq_groups +
  849. swapped_in.decode_seq_groups),
  850. num_prefill_groups=(len(prefills.seq_groups) +
  851. len(swapped_in.prefill_seq_groups) +
  852. len(running_scheduled.prefill_seq_groups)),
  853. num_batched_tokens=budget.num_batched_tokens,
  854. blocks_to_swap_in=swapped_in.blocks_to_swap_in,
  855. blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
  856. blocks_to_copy=running_scheduled.blocks_to_copy +
  857. swapped_in.blocks_to_copy,
  858. ignored_seq_groups=prefills.ignored_seq_groups +
  859. swapped_in.infeasible_seq_groups,
  860. num_lookahead_slots=running_scheduled.num_lookahead_slots,
  861. running_queue_size=len(self.running),
  862. preempted=(len(running_scheduled.preempted) +
  863. len(running_scheduled.swapped_out)),
  864. )
  865. def _schedule(self) -> SchedulerOutputs:
  866. """Schedule queued requests."""
  867. if self.scheduler_config.chunked_prefill_enabled:
  868. return self._schedule_chunked_prefill()
  869. else:
  870. return self._schedule_default()
  871. def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
  872. """Determine whether or not we have enough space in the KV cache to
  873. continue generation of the sequence group.
  874. """
  875. # It is True only for testing case to trigger artificial preemption.
  876. if (self.enable_artificial_preemption
  877. and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
  878. and self.artificial_preempt_cnt > 0):
  879. self.artificial_preempt_cnt -= 1
  880. return False
  881. # Appending slots only occurs in decoding.
  882. is_prefill = False
  883. return self.block_manager.can_append_slots(
  884. seq_group=seq_group,
  885. num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
  886. )
  887. def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
  888. # Schedule sequence groups.
  889. # This function call changes the internal states of the scheduler
  890. # such as self.running, self.swapped, and self.waiting.
  891. scheduler_outputs = self._schedule()
  892. now = time.time()
  893. # Create input data structures.
  894. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  895. for i, scheduled_seq_group in enumerate(
  896. scheduler_outputs.scheduled_seq_groups):
  897. seq_group = scheduled_seq_group.seq_group
  898. token_chunk_size = scheduled_seq_group.token_chunk_size
  899. seq_group.maybe_set_first_scheduled_time(now)
  900. # seq_id -> SequenceData
  901. seq_data: Dict[int, SequenceData] = {}
  902. # seq_id -> physical block numbers
  903. block_tables: Dict[int, List[int]] = {}
  904. if seq_group.is_encoder_decoder():
  905. # Encoder associated with SequenceGroup
  906. encoder_seq_data = seq_group.get_encoder_seq().data
  907. # Block table for cross-attention
  908. # Also managed at SequenceGroup level
  909. cross_block_table = self.block_manager.get_cross_block_table(
  910. seq_group)
  911. else:
  912. encoder_seq_data = None
  913. cross_block_table = None
  914. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  915. seq_id = seq.seq_id
  916. seq_data[seq_id] = seq.data
  917. block_tables[seq_id] = self.block_manager.get_block_table(seq)
  918. self.block_manager.access_all_blocks_in_seq(seq, now)
  919. if (self.cache_config.enable_prefix_caching and
  920. not seq_group.sampling_params.prompt_logprobs):
  921. common_computed_block_nums = (
  922. self.block_manager.get_common_computed_block_ids(
  923. seq_group.get_seqs(status=SequenceStatus.RUNNING)))
  924. else:
  925. common_computed_block_nums = []
  926. do_sample = True
  927. is_prompt = seq_group.is_prefill()
  928. # We should send the metadata to workers when the first prefill
  929. # is sent. Subsequent requests could be chunked prefill or decode.
  930. is_first_prefill = False
  931. if is_prompt:
  932. seqs = seq_group.get_seqs()
  933. # Prefill has only 1 sequence.
  934. assert len(seqs) == 1
  935. num_computed_tokens = seqs[0].data.get_num_computed_tokens()
  936. is_first_prefill = num_computed_tokens == 0
  937. # In the next iteration, all prompt tokens are not computed.
  938. # It means the prefill is chunked, and we don't need sampling.
  939. # NOTE: We use get_len instead of get_prompt_len because when
  940. # a sequence is preempted, prefill includes previous generated
  941. # output tokens.
  942. if (token_chunk_size + num_computed_tokens <
  943. seqs[0].data.get_len()):
  944. do_sample = False
  945. # It assumes the scheduled_seq_groups is ordered by
  946. # prefill < decoding.
  947. if is_first_prefill or not self.scheduler_config.send_delta_data:
  948. seq_group_metadata = SequenceGroupMetadata(
  949. request_id=seq_group.request_id,
  950. is_prompt=is_prompt,
  951. seq_data=seq_data,
  952. sampling_params=seq_group.sampling_params,
  953. block_tables=block_tables,
  954. do_sample=do_sample,
  955. pooling_params=seq_group.pooling_params,
  956. token_chunk_size=token_chunk_size,
  957. lora_request=seq_group.lora_request,
  958. computed_block_nums=common_computed_block_nums,
  959. encoder_seq_data=encoder_seq_data,
  960. cross_block_table=cross_block_table,
  961. state=seq_group.state,
  962. # `multi_modal_data` will only be present for the 1st comm
  963. # between engine and worker.
  964. # the subsequent comms can still use delta, but
  965. # `multi_modal_data` will be None.
  966. multi_modal_data=seq_group.multi_modal_data
  967. if scheduler_outputs.num_prefill_groups > 0 else None,
  968. prompt_adapter_request=seq_group.prompt_adapter_request,
  969. )
  970. else:
  971. # When SPMD mode is enabled, we only send delta data except for
  972. # the first request to reduce serialization cost.
  973. seq_data_delta = {}
  974. for id, data in seq_data.items():
  975. seq_data_delta[id] = data.get_delta_and_reset()
  976. seq_group_metadata = SequenceGroupMetadataDelta(
  977. seq_data_delta,
  978. seq_group.request_id,
  979. block_tables,
  980. is_prompt,
  981. do_sample=do_sample,
  982. token_chunk_size=token_chunk_size,
  983. computed_block_nums=common_computed_block_nums,
  984. )
  985. seq_group_metadata_list.append(seq_group_metadata)
  986. # Now that the batch has been created, we can assume all blocks in the
  987. # batch will have been computed before the next scheduling invocation.
  988. # This is because the engine assumes that a failure in model execution
  989. # will crash the Aphrodite instance / will not retry.
  990. for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
  991. self.block_manager.mark_blocks_as_computed(
  992. scheduled_seq_group.seq_group,
  993. scheduled_seq_group.token_chunk_size)
  994. return seq_group_metadata_list, scheduler_outputs
  995. def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
  996. self.block_manager.fork(parent_seq, child_seq)
  997. def free_seq(self, seq: Sequence) -> None:
  998. """Free a sequence from a block table."""
  999. self.block_manager.free(seq)
  1000. def free_finished_seq_groups(self) -> None:
  1001. remaining: Deque[SequenceGroup] = deque()
  1002. for seq_group in self.running:
  1003. if seq_group.is_finished():
  1004. # Free cross-attention block table, if it exists
  1005. self._free_seq_group_cross_attn_blocks(seq_group)
  1006. # Add the finished requests to the finished requests list.
  1007. # This list will be used to update the Mamba cache in the
  1008. # next step.
  1009. self._finished_requests_ids.append(seq_group.request_id)
  1010. else:
  1011. remaining.append(seq_group)
  1012. self.running = remaining
  1013. def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
  1014. self.block_manager.allocate(seq_group)
  1015. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
  1016. seq.status = SequenceStatus.RUNNING
  1017. def _append_slots(
  1018. self,
  1019. seq_group: SequenceGroup,
  1020. blocks_to_copy: List[Tuple[int, int]],
  1021. ) -> None:
  1022. """Appends new slots to the sequences in the given sequence group.
  1023. Args:
  1024. seq_group (SequenceGroup): The sequence group containing the
  1025. sequences to append slots to.
  1026. blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
  1027. ints, the first int is the source block index, and the second
  1028. int is the destination block index. This list is updated with
  1029. the new source and destination block indices for the appended
  1030. slots.
  1031. """
  1032. num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
  1033. seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
  1034. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  1035. cows = self.block_manager.append_slots(seq, num_lookahead_slots)
  1036. if len(cows) > 0:
  1037. blocks_to_copy.extend(cows)
  1038. def _preempt(
  1039. self,
  1040. seq_group: SequenceGroup,
  1041. blocks_to_swap_out: List[Tuple[int, int]],
  1042. preemption_mode: Optional[PreemptionMode] = None,
  1043. ) -> PreemptionMode:
  1044. # If preemption mode is not specified, we determine the mode as follows:
  1045. # We use recomputation by default since it incurs lower overhead than
  1046. # swapping. However, when the sequence group has multiple sequences
  1047. # (e.g., beam search), recomputation is not currently supported. In
  1048. # such a case, we use swapping instead.
  1049. # FIXME: This makes our scheduling policy a bit bizarre.
  1050. # As swapped sequences are prioritized over waiting sequences,
  1051. # sequence groups with multiple sequences are implicitly prioritized
  1052. # over sequence groups with a single sequence.
  1053. # TODO: Support recomputation for sequence groups with multiple
  1054. # sequences. This may require a more sophisticated CUDA kernel.
  1055. if self.user_specified_preemption_mode is None:
  1056. if seq_group.get_max_num_running_seqs() == 1:
  1057. preemption_mode = PreemptionMode.RECOMPUTE
  1058. else:
  1059. preemption_mode = PreemptionMode.SWAP
  1060. elif self.user_specified_preemption_mode == "swap":
  1061. preemption_mode = PreemptionMode.SWAP
  1062. else:
  1063. preemption_mode = PreemptionMode.RECOMPUTE
  1064. if self.num_cumulative_preemption % 50 == 0:
  1065. logger.warning(
  1066. f"Sequence group {seq_group.request_id} is preempted by "
  1067. f"{preemption_mode} mode because there is "
  1068. "not enough KV cache space. This can affect the end-to-end "
  1069. "performance. Increase gpu_memory_utilization or "
  1070. "tensor_parallel_size to provide more KV cache memory. "
  1071. "total_num_cumulative_preemption="
  1072. f"{self.num_cumulative_preemption + 1}")
  1073. self.num_cumulative_preemption += 1
  1074. if preemption_mode == PreemptionMode.RECOMPUTE:
  1075. self._preempt_by_recompute(seq_group)
  1076. elif preemption_mode == PreemptionMode.SWAP:
  1077. self._preempt_by_swap(seq_group, blocks_to_swap_out)
  1078. else:
  1079. raise AssertionError("Invalid preemption mode.")
  1080. return preemption_mode
  1081. def _preempt_by_recompute(
  1082. self,
  1083. seq_group: SequenceGroup,
  1084. ) -> None:
  1085. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  1086. assert len(seqs) == 1
  1087. for seq in seqs:
  1088. seq.status = SequenceStatus.WAITING
  1089. self.free_seq(seq)
  1090. seq.reset_state_for_recompute()
  1091. def _preempt_by_swap(
  1092. self,
  1093. seq_group: SequenceGroup,
  1094. blocks_to_swap_out: List[Tuple[int, int]],
  1095. ) -> None:
  1096. self._swap_out(seq_group, blocks_to_swap_out)
  1097. def _swap_in(
  1098. self,
  1099. seq_group: SequenceGroup,
  1100. blocks_to_swap_in: List[Tuple[int, int]],
  1101. ) -> None:
  1102. mapping = self.block_manager.swap_in(seq_group)
  1103. blocks_to_swap_in.extend(mapping)
  1104. for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
  1105. seq.status = SequenceStatus.RUNNING
  1106. def _swap_out(
  1107. self,
  1108. seq_group: SequenceGroup,
  1109. blocks_to_swap_out: List[Tuple[int, int]],
  1110. ) -> None:
  1111. if not self.block_manager.can_swap_out(seq_group):
  1112. # FIXME: Abort the sequence group instead of aborting the
  1113. # entire engine.
  1114. raise RuntimeError(
  1115. "Aborted due to the lack of CPU swap space. Please increase "
  1116. "the swap space to avoid this error.")
  1117. mapping = self.block_manager.swap_out(seq_group)
  1118. blocks_to_swap_out.extend(mapping)
  1119. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  1120. seq.status = SequenceStatus.SWAPPED
  1121. def _passed_delay(self, now: float) -> bool:
  1122. if self.prev_prompt:
  1123. self.last_prompt_latency = now - self.prev_time
  1124. self.prev_time, self.prev_prompt = now, False
  1125. # Delay scheduling prompts to let waiting queue fill up
  1126. if self.scheduler_config.delay_factor > 0 and self.waiting:
  1127. earliest_arrival_time = min(
  1128. [e.metrics.arrival_time for e in self.waiting])
  1129. passed_delay = (
  1130. (now - earliest_arrival_time) >
  1131. (self.scheduler_config.delay_factor * self.last_prompt_latency)
  1132. or not self.running)
  1133. else:
  1134. passed_delay = True
  1135. return passed_delay
  1136. def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
  1137. """The number of slots to allocate per sequence per step, beyond known
  1138. token ids. Speculative decoding uses these slots to store KV activations
  1139. of tokens which may or may not be accepted.
  1140. Speculative decoding does not yet support prefill, so we do not perform
  1141. lookahead allocation for prefill.
  1142. """
  1143. if is_prefill:
  1144. return 0
  1145. return self.scheduler_config.num_lookahead_slots
  1146. def _get_num_new_tokens(self, seq_group: SequenceGroup,
  1147. status: SequenceStatus, enable_chunking: bool,
  1148. budget: SchedulingBudget) -> int:
  1149. """Get the next new tokens to compute for a given sequence group
  1150. that's in a given `status`.
  1151. The API could chunk the number of tokens to compute based on `budget`
  1152. if `enable_chunking` is True. If a sequence group has multiple
  1153. sequences (e.g., running beam search), it means it is in decoding
  1154. phase, so chunking doesn't happen.
  1155. Returns 0 if the new token cannot be computed due to token budget.
  1156. """
  1157. num_new_tokens = 0
  1158. seqs = seq_group.get_seqs(status=status)
  1159. for seq in seqs:
  1160. num_new_tokens += seq.get_num_new_tokens()
  1161. assert num_new_tokens > 0
  1162. # Chunk if a running request cannot fit in the given budget.
  1163. # If number of seq > 1, it means it is doing beam search
  1164. # in a decode phase. Do not chunk.
  1165. if enable_chunking and len(seqs) == 1:
  1166. remaining_token_budget = budget.remaining_token_budget()
  1167. if self.cache_config.enable_prefix_caching:
  1168. # When prefix caching is enabled, we always allocate
  1169. # the number of new tokens that is dividable by the block size
  1170. # to avoid partial block matching.
  1171. block_size = self.cache_config.block_size
  1172. reminder = budget.token_budget % block_size
  1173. if reminder != 0:
  1174. raise ValueError("When enabling chunked prefill and "
  1175. "prefix caching, max_num_batched_tokens "
  1176. "(chunk size) must be dividable by "
  1177. "block size, but got chunk_size "
  1178. f"({budget.token_budget}) % block_size "
  1179. f"({block_size}) = {reminder}")
  1180. if remaining_token_budget < num_new_tokens:
  1181. num_new_tokens = (remaining_token_budget //
  1182. block_size) * block_size
  1183. else:
  1184. num_new_tokens = min(num_new_tokens, remaining_token_budget)
  1185. return num_new_tokens