1
0

test_scheduler.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852
  1. import time
  2. from collections import deque
  3. from typing import List, Set, Tuple
  4. from unittest.mock import MagicMock
  5. import pytest # noqa
  6. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  7. from aphrodite.common.sequence import SequenceGroup, SequenceStatus
  8. from aphrodite.lora.request import LoRARequest
  9. from aphrodite.processing.interfaces import AllocStatus
  10. from aphrodite.processing.scheduler import Scheduler, SchedulingBudget
  11. from .utils import (append_new_token, append_new_token_seq_group,
  12. create_dummy_prompt, get_sequence_groups,
  13. schedule_and_update_computed_tokens)
  def test_scheduler_add_seq_group():
      """Adding sequence groups one at a time grows the unfinished count."""
      tokens_per_block = 4
      sched_cfg = SchedulerConfig(100, 64, 1)
      kv_cfg = CacheConfig(tokens_per_block, 1.0, 1, cache_dtype="auto")
      kv_cfg.num_cpu_blocks = 4
      kv_cfg.num_gpu_blocks = 4
      scheduler = Scheduler(sched_cfg, kv_cfg, None)

      # Every added group must show up immediately in the unfinished count.
      total_groups = 4
      for expected_count in range(1, total_groups + 1):
          request_id = str(expected_count - 1)
          _, group = create_dummy_prompt(request_id, tokens_per_block)
          scheduler.add_seq_group(group)
          assert scheduler.get_num_unfinished_seq_groups() == expected_count
  def test_scheduler_abort_seq_group():
      """Aborting every queued request leaves no unfinished sequence groups."""
      tokens_per_block = 4
      sched_cfg = SchedulerConfig(100, 64, 1)
      kv_cfg = CacheConfig(tokens_per_block, 1.0, 1, "auto")
      kv_cfg.num_cpu_blocks = 4
      kv_cfg.num_gpu_blocks = 4
      scheduler = Scheduler(sched_cfg, kv_cfg, None)

      # Queue several groups, keeping the request id of each one.
      total_groups = 4
      queued_ids = [str(n) for n in range(total_groups)]
      for request_id in queued_ids:
          _, group = create_dummy_prompt(request_id, tokens_per_block)
          scheduler.add_seq_group(group)
      request_ids: Set[str] = set(queued_ids)

      # One abort call covering all ids must empty the scheduler.
      assert scheduler.get_num_unfinished_seq_groups() == total_groups
      scheduler.abort_seq_group(request_ids)
      assert scheduler.get_num_unfinished_seq_groups() == 0
  def test_scheduler_schedule_simple():
      """A prefill step followed by a decode step schedules every group."""
      block_size = 4
      num_seq_group = 4
      max_model_len = 16
      sched_cfg = SchedulerConfig(64, num_seq_group, max_model_len)
      kv_cfg = CacheConfig(block_size, 1.0, 1, "auto")
      kv_cfg.num_cpu_blocks = 8
      kv_cfg.num_gpu_blocks = 8
      scheduler = Scheduler(sched_cfg, kv_cfg, None)

      # Queue one block-sized prompt per group.
      expected_groups: List[SequenceGroup] = []
      for idx in range(num_seq_group):
          _, group = create_dummy_prompt(str(idx), prompt_length=block_size)
          scheduler.add_seq_group(group)
          expected_groups.append(group)

      # Prefill step: all prompt tokens of all groups are batched together.
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert set(get_sequence_groups(out)) == set(expected_groups)
      assert out.num_batched_tokens == block_size * num_seq_group
      assert not (out.blocks_to_copy or out.blocks_to_swap_in
                  or out.blocks_to_swap_out)
      assert len(metadata) == num_seq_group
      append_new_token(out, 1)

      # Decode step: one token per group, still no block movement.
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert set(get_sequence_groups(out)) == set(expected_groups)
      assert out.num_batched_tokens == num_seq_group
      assert not (out.blocks_to_copy or out.blocks_to_swap_in
                  or out.blocks_to_swap_out)
      assert len(metadata) == num_seq_group
      append_new_token(out, 1)
  def test_scheduler_prefill_prioritized():
      """Verify running batched tokens are not applied to prefill requests."""
      block_size = 4
      # The same limit (30) serves as the batched-token budget and as the
      # maximum model length.
      token_limit = 30
      sched_cfg = SchedulerConfig(token_limit, 2, token_limit)
      kv_cfg = CacheConfig(block_size, 1.0, 1, "auto")
      kv_cfg.num_cpu_blocks = 2
      kv_cfg.num_gpu_blocks = 2
      scheduler = Scheduler(sched_cfg, kv_cfg, None)

      # Request A: a one-token prompt, scheduled on its own first.
      _, group_a = create_dummy_prompt("1", 1)
      scheduler.add_seq_group(group_a)
      _, out = schedule_and_update_computed_tokens(scheduler)
      assert get_sequence_groups(out) == [group_a]

      # Request B: a 30-token prompt, i.e. as large as the whole token budget.
      _, group_b = create_dummy_prompt("2", 30)
      scheduler.add_seq_group(group_b)

      # Prefill is prioritized: the next step must contain only B.
      _, out = schedule_and_update_computed_tokens(scheduler)
      assert get_sequence_groups(out) == [group_b]
  def test_scheduler_schedule_preempt_abort():
      """Preempt one of two groups, abort the survivor, then recompute the other."""
      block_size = 4
      max_model_len = 16
      sched_cfg = SchedulerConfig(64, 2, max_model_len)
      kv_cfg = CacheConfig(block_size, 1.0, 1, "auto")
      kv_cfg.num_cpu_blocks = 2
      kv_cfg.num_gpu_blocks = 2
      scheduler = Scheduler(sched_cfg, kv_cfg, None)

      # Two block-sized prompts; only the groups are needed below.
      _, group_a = create_dummy_prompt("1", block_size)
      _, group_b = create_dummy_prompt("2", block_size)
      for group in (group_a, group_b):
          scheduler.add_seq_group(group)

      # Step 1: both prompts are prefilled together.
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert get_sequence_groups(out) == [group_a, group_b]
      assert out.num_batched_tokens == block_size * 2  # group_a and group_b
      assert not (out.blocks_to_copy or out.blocks_to_swap_in
                  or out.blocks_to_swap_out)
      assert len(metadata) == 2
      assert scheduler.get_num_unfinished_seq_groups() == 2

      # Appending a "generated" token lets the sequences mark their prompt
      # tokens as processed.
      append_new_token(out, 1)

      # Step 2: only group A decodes; group B is preempted.
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert get_sequence_groups(out) == [group_a]
      assert out.num_batched_tokens == 1
      assert not (out.blocks_to_copy or out.blocks_to_swap_in
                  or out.blocks_to_swap_out)
      assert len(metadata) == 1
      assert scheduler.get_num_unfinished_seq_groups() == 2
      assert out.preempted == 1

      # Step 3: abort group A; group B is rescheduled with recomputation.
      scheduler.abort_seq_group("1")
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert get_sequence_groups(out) == [group_b]
      assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
      assert not (out.blocks_to_copy or out.blocks_to_swap_in
                  or out.blocks_to_swap_out)
      assert len(metadata) == 1
      assert scheduler.get_num_unfinished_seq_groups() == 1
  def test_scheduler_max_seqs():
      """Only one of two new prompts is admitted when the group limit is 2."""
      block_size = 4
      num_seq_group = 4
      max_seq_group = 2
      max_model_len = 16
      sched_cfg = SchedulerConfig(64, max_seq_group, max_model_len)
      kv_cfg = CacheConfig(block_size, 1.0, 1, "auto")
      kv_cfg.num_cpu_blocks = 8
      kv_cfg.num_gpu_blocks = 8
      scheduler = Scheduler(sched_cfg, kv_cfg, None)

      # Build all groups up front; they are handed to the scheduler in stages.
      all_seq_groups: List[SequenceGroup] = [
          create_dummy_prompt(str(idx), prompt_length=block_size)[1]
          for idx in range(num_seq_group)
      ]

      # Stage 1: a single group goes through prefill and one decode step.
      scheduler.add_seq_group(all_seq_groups[0])
      _, out = schedule_and_update_computed_tokens(scheduler)
      assert set(get_sequence_groups(out)) == {all_seq_groups[0]}
      append_new_token(out, 1)

      _, out = schedule_and_update_computed_tokens(scheduler)
      assert set(get_sequence_groups(out)) == {all_seq_groups[0]}
      append_new_token(out, 1)

      # Stage 2: two more groups arrive while the first one is still active.
      scheduler.add_seq_group(all_seq_groups[1])
      scheduler.add_seq_group(all_seq_groups[2])

      # With max_seq_group == 2, only one of the new prompts may be
      # scheduled in this step.
      _, out = schedule_and_update_computed_tokens(scheduler)
      assert set(get_sequence_groups(out)) == {all_seq_groups[1]}
  def test_scheduler_delay_factor():
      """With delay_factor=0.5 a new prompt is held back before being prefilled."""
      block_size = 4
      sched_cfg = SchedulerConfig(100, 64, 16, delay_factor=0.5)
      kv_cfg = CacheConfig(block_size, 1.0, 1, "auto")
      kv_cfg.num_cpu_blocks = 8
      kv_cfg.num_gpu_blocks = 8
      scheduler = Scheduler(sched_cfg, kv_cfg, None)

      # The first prompt is prefilled right away.
      _, first_group = create_dummy_prompt("0", prompt_length=block_size)
      scheduler.add_seq_group(first_group)
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert out.num_prefill_groups > 0
      assert metadata[0].request_id == '0'
      append_new_token(out, 1)

      # Wait a second, then queue the second prompt.
      time.sleep(1)
      _, second_group = create_dummy_prompt("1", prompt_length=block_size)
      scheduler.add_seq_group(second_group)

      # The second prompt must *not* be scheduled yet; request '0' runs instead.
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert out.num_prefill_groups == 0
      assert metadata[0].request_id == '0'
      append_new_token(out, 1)

      # After waiting more than 0.5 seconds the second prompt is prefilled.
      time.sleep(0.6)
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert out.num_prefill_groups > 0
      assert metadata[0].request_id == '1'
      append_new_token(out, 1)
  def test_swapped_out_prioritized():
      """A swapped-out group is brought back in before a new prompt is prefilled."""
      scheduler = initialize_scheduler(max_num_seqs=6)

      # Three groups with best_of=2, i.e. 2 * 3 == 6 sequences.
      for idx in range(3):
          _, group = create_dummy_prompt(str(idx), prompt_length=60, best_of=2)
          scheduler.add_seq_group(group)

      # All three prompts are prefilled in the first step.
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert len(out.scheduled_seq_groups) == 3
      append_new_token(out, 1)

      # Refuse new slots for request "2" so that it gets swapped out.
      def deny_request_two(seq_group, num_lookahead_slots):
          return seq_group.request_id != "2"

      scheduler.block_manager.can_append_slots = MagicMock(
          side_effect=deny_request_two)

      metadata, out = schedule_and_update_computed_tokens(scheduler)
      assert len(out.scheduled_seq_groups) == 2
      assert out.num_batched_tokens == 2
      assert out.blocks_to_swap_out != []
      assert out.blocks_to_swap_in == []
      append_new_token(out, 1)

      # Queue one more prompt. NOTE: `idx` is the leaked loop index (2), so the
      # new request reuses the id "2".
      _, group = create_dummy_prompt(str(idx), prompt_length=60, best_of=2)
      scheduler.add_seq_group(group)
      metadata, out = schedule_and_update_computed_tokens(scheduler)
      append_new_token(out, 1)

      # Swap-in wins over the new prefill: three decodes, nothing swapped out.
      assert len(out.scheduled_seq_groups) == 3
      assert out.num_batched_tokens == 3
      assert out.blocks_to_swap_in != []
      assert out.blocks_to_swap_out == []
  def initialize_scheduler(*,
                           max_num_seqs=1000,
                           max_token_budget=1000,
                           max_model_len=1000,
                           lora_config=None):
      """Build a Scheduler with block size 4 and 8 CPU / 8 GPU cache blocks.

      All arguments are keyword-only. ``max_token_budget``, ``max_num_seqs``
      and ``max_model_len`` are passed positionally, in that order, to
      ``SchedulerConfig``; ``lora_config`` is handed to ``Scheduler`` as is.
      """
      sched_cfg = SchedulerConfig(max_token_budget, max_num_seqs,
                                  max_model_len)
      kv_cfg = CacheConfig(4, 1.0, 1, "auto")
      kv_cfg.num_cpu_blocks = 8
      kv_cfg.num_gpu_blocks = 8
      return Scheduler(sched_cfg, kv_cfg, lora_config)
  def create_token_budget(token_budget: int = 10000,
                          max_num_seqs: int = 10000) -> SchedulingBudget:
      """Return a fresh SchedulingBudget with the given token and sequence limits."""
      limits = {"token_budget": token_budget, "max_num_seqs": max_num_seqs}
      return SchedulingBudget(**limits)
  def add_token_budget(budget: SchedulingBudget,
                       num_batched_tokens: int = 0,
                       num_curr_seqs: int = 0):
      """Charge tokens and sequences to ``budget`` on behalf of a dummy request.

      A throwaway 60-token prompt with request id '10' is created only to
      supply the request id under which both amounts are recorded.
      """
      _, placeholder_group = create_dummy_prompt('10', prompt_length=60)
      request_id = placeholder_group.request_id
      budget.add_num_batched_tokens(request_id, num_batched_tokens)
      budget.add_num_seqs(request_id, num_curr_seqs)
  def test_prefill_schedule_max_prompt_len():
      """
      A 60-token prompt with max_model_len=30 ends up in the ignored list.
      """
      scheduler = initialize_scheduler(max_model_len=30)
      _, oversized_group = create_dummy_prompt("0", prompt_length=60)
      scheduler.add_seq_group(oversized_group)

      budget = create_token_budget()
      output = scheduler._schedule_prefills(budget, None)

      # The request is dropped from the queue without consuming any budget.
      assert len(output.ignored_seq_groups) == 1
      assert len(output.seq_groups) == 0
      assert budget.num_batched_tokens == 0
      assert budget.num_curr_seqs == 0
      assert len(scheduler.waiting) == 0
  def test_prefill_schedule_token_budget():
      """
      Prefill scheduling respects the token budget, including tokens already
      charged to it.
      """
      scheduler = initialize_scheduler()
      budget = create_token_budget(token_budget=0)
      for idx in range(2):
          _, group = create_dummy_prompt(str(idx), prompt_length=60)
          scheduler.add_seq_group(group)

      # A zero-token budget schedules nothing.
      output = scheduler._schedule_prefills(budget, None)
      assert len(output.ignored_seq_groups) == 0
      assert len(output.seq_groups) == 0
      assert budget.num_batched_tokens == 0
      assert budget.num_curr_seqs == 0
      assert len(scheduler.waiting) == 2

      # A 60-token budget fits exactly one 60-token prompt.
      budget = create_token_budget(token_budget=60)
      output = scheduler._schedule_prefills(budget, None)
      assert len(output.ignored_seq_groups) == 0
      assert len(output.seq_groups) == 1
      assert budget.num_batched_tokens == 60
      assert budget.num_curr_seqs == 1
      assert len(scheduler.waiting) == 1

      # Tokens already charged to the budget count against it: with 30 of 60
      # tokens used, a 60-token prompt does not fit.
      scheduler = initialize_scheduler()
      budget = create_token_budget(token_budget=60)
      add_token_budget(budget, 30, 0)
      # NOTE: `idx` is the leaked loop index (1), so the request id is "1".
      _, group = create_dummy_prompt(str(idx), prompt_length=60)
      scheduler.add_seq_group(group)
      output = scheduler._schedule_prefills(budget, None)
      assert len(output.ignored_seq_groups) == 0
      assert len(output.seq_groups) == 0
      assert budget.num_batched_tokens == 30
      assert budget.num_curr_seqs == 0
      assert len(scheduler.waiting) == 1

      # With a 90-token budget and 30 tokens used, the prompt fits.
      budget = create_token_budget(token_budget=90)
      add_token_budget(budget, 30, 0)
      output = scheduler._schedule_prefills(budget, None)
      assert len(output.seq_groups) == 1
      assert budget.num_batched_tokens == 90
      assert budget.num_curr_seqs == 1
      assert len(scheduler.waiting) == 0
  def test_prefill_schedule_max_seqs():
      """
      Prefill scheduling respects the sequence limit, including sequences
      already charged to the budget.
      """
      scheduler = initialize_scheduler()
      budget = create_token_budget(max_num_seqs=2)
      for idx in range(3):
          _, group = create_dummy_prompt(str(idx), prompt_length=60)
          scheduler.add_seq_group(group)

      # Only two of the three prompts fit under max_num_seqs=2.
      output = scheduler._schedule_prefills(budget, None)
      assert len(output.ignored_seq_groups) == 0
      assert len(output.seq_groups) == 2
      assert budget.num_batched_tokens == 120
      assert budget.num_curr_seqs == 2
      assert len(scheduler.waiting) == 1

      # Start from an empty queue with the sequence budget already used up.
      scheduler.waiting = deque()
      budget = create_token_budget(max_num_seqs=2)
      add_token_budget(budget, 0, 2)
      # NOTE: `idx` is the leaked loop index (2), so the request id is "2".
      _, group = create_dummy_prompt(str(idx), prompt_length=60)
      scheduler.add_seq_group(group)

      # No sequence slot is left, so the new prompt keeps waiting.
      output = scheduler._schedule_prefills(budget, None)
      assert len(output.ignored_seq_groups) == 0
      assert len(output.seq_groups) == 0
      assert budget.num_batched_tokens == 0
      assert budget.num_curr_seqs == 2
      assert len(scheduler.waiting) == 1
  def test_prefill_schedule_max_lora():
      """
      With max_loras=1 only one LoRA request is prefilled per pass, and the
      LoRA request that was held back is scheduled on the next pass.
      """
      lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
      scheduler = initialize_scheduler(lora_config=lora_config)
      budget = create_token_budget(token_budget=120)
      curr_loras: Set[int] = set()

      # Requests 0 and 1 each use their own LoRA adapter (ids 1 and 2).
      for idx in range(2):
          adapter = LoRARequest(lora_name=str(idx),
                                lora_int_id=idx + 1,
                                lora_path="abc")
          _, group = create_dummy_prompt(str(idx),
                                         prompt_length=60,
                                         lora_request=adapter)
          scheduler.add_seq_group(group)

      # Requests 2 and 3 are regular prompts, so the queue is
      # 0: LoRA, 1: LoRA, 2: regular, 3: regular.
      for idx in range(2, 4):
          _, group = create_dummy_prompt(str(idx), prompt_length=60)
          scheduler.add_seq_group(group)

      # First pass: two requests are scheduled (expected to be 0 and 2) and
      # exactly one LoRA is active afterwards.
      output = scheduler._schedule_prefills(budget, curr_loras)
      assert len(output.ignored_seq_groups) == 0
      assert len(output.seq_groups) == 2
      assert budget.num_batched_tokens == 120
      assert budget.num_curr_seqs == 2
      assert len(scheduler.waiting) == 2
      assert len(curr_loras) == 1

      # Second pass: with curr_loras reset, the held-back LoRA request "1"
      # is the one scheduled next (FCFS).
      curr_loras = set()
      budget = create_token_budget(token_budget=60)
      output = scheduler._schedule_prefills(budget, curr_loras)
      assert len(output.seq_groups) == 1
      assert output.seq_groups[0].seq_group.request_id == "1"
      assert len(scheduler.waiting) == 1
      assert len(curr_loras) == 1
      assert budget.num_batched_tokens == 60
  def test_prefill_schedule_no_block_manager_capacity():
      """
      Prompts keep waiting when allocation is reported as LATER and are
      ignored when it is reported as NEVER.
      """
      # Case 1: the block manager answers LATER, so every prompt stays queued.
      scheduler = initialize_scheduler()
      budget = create_token_budget()
      for idx in range(3):
          _, group = create_dummy_prompt(str(idx), prompt_length=60)
          scheduler.add_seq_group(group)
      scheduler.block_manager.can_allocate = MagicMock(
          return_value=AllocStatus.LATER)

      output = scheduler._schedule_prefills(budget, None)
      assert len(output.ignored_seq_groups) == 0
      assert len(output.seq_groups) == 0
      assert budget.num_batched_tokens == 0
      assert budget.num_curr_seqs == 0
      assert len(scheduler.waiting) == 3

      # Case 2: the block manager answers NEVER, so every prompt is ignored.
      scheduler = initialize_scheduler()
      budget = create_token_budget()
      for idx in range(3):
          _, group = create_dummy_prompt(str(idx), prompt_length=60)
          scheduler.add_seq_group(group)
      scheduler.block_manager.can_allocate = MagicMock(
          return_value=AllocStatus.NEVER)

      output = scheduler._schedule_prefills(budget, None)
      assert len(output.ignored_seq_groups) == 3
      assert len(output.seq_groups) == 0
      assert budget.num_batched_tokens == 0
      assert budget.num_curr_seqs == 0
      assert len(scheduler.waiting) == 0
  def test_decode_schedule_preempted():
      """
      When slots cannot be appended for one running group, only request "0"
      decodes and the other two groups are preempted.
      """
      scheduler = initialize_scheduler()
      curr_loras = None
      for idx in range(3):
          _, group = create_dummy_prompt(str(idx), prompt_length=60)
          scheduler._allocate_and_set_running(group)
          append_new_token_seq_group(60, group, 1)
          scheduler._add_seq_group_to_running(group)

      # Refuse new slots for request "1" only.
      def deny_request_one(seq_group, num_lookahead_slots):
          return seq_group.request_id != "1"

      scheduler.block_manager.can_append_slots = MagicMock(
          side_effect=deny_request_one)

      # Request 1 cannot be scheduled, so the lowest-priority request (2) is
      # preempted first and request 1 is preempted as well.
      budget = create_token_budget()
      output = scheduler._schedule_running(budget, curr_loras)
      assert len(scheduler.running) == 0
      assert len(output.decode_seq_groups) == 1
      assert len(output.prefill_seq_groups) == 0
      assert output.decode_seq_groups[0].seq_group.request_id == "0"
      assert len(output.preempted) == 2

      # The budget reflects the single decode token.
      assert budget.num_batched_tokens == 1
      # NOTE (kept from the original test): when enable_chunk is False the
      # num_seqs budget is not updated, so num_curr_seqs is not asserted.

      # Both groups were preempted rather than swapped, and nothing is copied.
      assert output.blocks_to_swap_out == []
      assert output.blocks_to_copy == []
  def test_decode_swap_beam_search():
      """
      A best_of=2 group that cannot get new slots is swapped out, and the
      block manager's swap mapping is reported in the output.
      """
      scheduler = initialize_scheduler()
      curr_loras = None
      budget = create_token_budget()
      for idx in range(3):
          _, group = create_dummy_prompt(str(idx), prompt_length=60, best_of=2)
          scheduler._allocate_and_set_running(group)
          scheduler._add_seq_group_to_running(group)
          append_new_token_seq_group(60, group, 1)
          # Pre-charge the budget as if the group were already scheduled.
          budget.add_num_seqs(group.request_id,
                              group.get_max_num_running_seqs())
          budget.add_num_batched_tokens(
              group.request_id, group.num_seqs(SequenceStatus.RUNNING))

      # Refuse new slots for the last request ("2") so it has to be swapped.
      def deny_request_two(seq_group, num_lookahead_slots):
          return seq_group.request_id != "2"

      scheduler.block_manager.can_append_slots = MagicMock(
          side_effect=deny_request_two)

      # Stub swap_out so the resulting mapping is known in advance.
      expected_swap_mapping = [("5", "7")]
      scheduler.block_manager.swap_out = MagicMock(
          return_value=expected_swap_mapping)

      output = scheduler._schedule_running(budget, curr_loras)
      assert len(scheduler.running) == 0
      assert len(output.decode_seq_groups) == 2
      assert len(output.prefill_seq_groups) == 0
      assert output.decode_seq_groups[0].seq_group.request_id == "0"
      assert output.decode_seq_groups[1].seq_group.request_id == "1"
      assert len(output.preempted) == 0
      assert len(output.swapped_out) == 1

      # The budget no longer includes the swapped-out group: 2 batched tokens
      # and 4 sequences remain charged.
      assert budget.num_batched_tokens == 2
      assert budget.num_curr_seqs == 4

      # The stubbed swap mapping is surfaced, and nothing is copied.
      assert output.blocks_to_swap_out == expected_swap_mapping
      assert output.blocks_to_copy == []
  def test_schedule_decode_blocks_to_copy_update():
      """
      The copy mapping returned by append_slots shows up in blocks_to_copy.
      """
      scheduler = initialize_scheduler()
      _, group = create_dummy_prompt("1", prompt_length=60, best_of=2)
      curr_loras = None
      scheduler._allocate_and_set_running(group)
      append_new_token_seq_group(60, group, 1)
      scheduler._add_seq_group_to_running(group)

      # Stub append_slots so it reports one source -> destination block copy.
      scheduler.block_manager.append_slots = MagicMock(return_value=[(2, 3)])

      budget = create_token_budget()
      output = scheduler._schedule_running(budget, curr_loras)
      assert len(scheduler.running) == 0
      assert len(output.decode_seq_groups) == 1
      assert len(output.prefill_seq_groups) == 0
      assert len(output.preempted) == 0
      assert len(output.swapped_out) == 0

      # Nothing is preempted, so nothing is swapped out.
      assert output.blocks_to_swap_out == []
      # The mapping returned by append_slots must be applied unchanged.
      assert output.blocks_to_copy == [(2, 3)]
  def test_schedule_swapped_simple():
      """A swapped-out group is swapped back in with the reverse block mapping."""
      scheduler = initialize_scheduler()
      curr_loras = None
      blocks_to_swap_out: List[Tuple[int, int]] = []

      # Put one best_of=2 group into the swapped state.
      _, group = create_dummy_prompt("1", prompt_length=60, best_of=2)
      scheduler._allocate_and_set_running(group)
      append_new_token_seq_group(60, group, 1)
      scheduler._swap_out(group, blocks_to_swap_out)
      scheduler._add_seq_group_to_swapped(group)

      budget = create_token_budget()
      output = scheduler._schedule_swapped(budget, curr_loras)
      assert len(scheduler.swapped) == 0
      assert budget.num_batched_tokens == 1
      assert budget.num_curr_seqs == 2
      assert len(output.decode_seq_groups) == 1
      assert len(output.prefill_seq_groups) == 0

      # Swapping in is the reverse of swapping out: flip every pair and compare.
      flipped_swap_in = [(dst, src) for src, dst in output.blocks_to_swap_in]
      assert blocks_to_swap_out == flipped_swap_in
  def test_schedule_swapped_max_token_budget():
      """Swap-in stops once the token budget is used up."""
      scheduler = initialize_scheduler()
      curr_loras = None
      blocks_to_swap_out: List[Tuple[int, int]] = []

      # Two swapped-out best_of=2 groups, both with request id "1".
      for _ in range(2):
          _, group = create_dummy_prompt("1", prompt_length=60, best_of=2)
          scheduler._allocate_and_set_running(group)
          append_new_token_seq_group(60, group, 1)
          scheduler._swap_out(group, blocks_to_swap_out)
          scheduler._add_seq_group_to_swapped(group)

      # A one-token budget lets exactly one group back in.
      budget = create_token_budget(token_budget=1)
      output = scheduler._schedule_swapped(budget, curr_loras)
      assert len(scheduler.swapped) == 1
      assert budget.num_batched_tokens == 1
      assert budget.num_curr_seqs == 2
      assert len(output.decode_seq_groups) == 1
      assert len(output.prefill_seq_groups) == 0

      # With the single token already charged, nothing more is swapped in.
      budget = create_token_budget(token_budget=1)
      add_token_budget(budget, 1, 0)
      output = scheduler._schedule_swapped(budget, curr_loras)
      assert len(scheduler.swapped) == 1
      assert budget.num_batched_tokens == 1
      assert budget.num_curr_seqs == 0
      assert len(output.decode_seq_groups) == 0
      assert len(output.prefill_seq_groups) == 0
  def test_schedule_swapped_max_seqs():
      """Swap-in stops once the sequence budget is used up."""
      scheduler = initialize_scheduler()
      curr_loras = None
      blocks_to_swap_out: List[Tuple[int, int]] = []

      # Four swapped-out groups.
      for idx in range(4):
          _, group = create_dummy_prompt(str(idx), prompt_length=60)
          scheduler._allocate_and_set_running(group)
          append_new_token_seq_group(60, group, 1)
          scheduler._swap_out(group, blocks_to_swap_out)
          scheduler._add_seq_group_to_swapped(group)

      # max_num_seqs=2 admits two of the four groups.
      budget = create_token_budget(max_num_seqs=2)
      output = scheduler._schedule_swapped(budget, curr_loras)
      assert len(scheduler.swapped) == 2
      assert budget.num_batched_tokens == 2
      assert budget.num_curr_seqs == 2
      assert len(output.decode_seq_groups) == 2
      assert len(output.prefill_seq_groups) == 0

      # Reusing the exhausted budget swaps in nothing further.
      output = scheduler._schedule_swapped(budget, curr_loras)
      assert len(scheduler.swapped) == 2
      assert budget.num_batched_tokens == 2
      assert budget.num_curr_seqs == 2
      assert len(output.decode_seq_groups) == 0
      assert len(output.prefill_seq_groups) == 0
  def test_schedule_swapped_max_loras():
      """With max_loras=1 only one of two swapped LoRA groups is swapped in."""
      lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
      scheduler = initialize_scheduler(lora_config=lora_config)
      curr_loras: Set[int] = set()
      blocks_to_swap_out: List[Tuple[int, int]] = []

      # Two swapped-out groups, each with its own LoRA adapter (ids 1 and 2).
      for idx in range(2):
          adapter = LoRARequest(lora_name=str(idx),
                                lora_int_id=idx + 1,
                                lora_path="abc")
          _, group = create_dummy_prompt(str(idx),
                                         prompt_length=60,
                                         lora_request=adapter)
          scheduler._allocate_and_set_running(group)
          append_new_token_seq_group(60, group, 1)
          scheduler._swap_out(group, blocks_to_swap_out)
          scheduler._add_seq_group_to_swapped(group)

      budget = create_token_budget()
      output = scheduler._schedule_swapped(budget, curr_loras)

      # One group comes back; the other stays swapped.
      assert len(scheduler.swapped) == 1
      assert budget.num_batched_tokens == 1
      assert budget.num_curr_seqs == 1
      assert len(output.decode_seq_groups) == 1
      assert len(output.prefill_seq_groups) == 0
      assert len(curr_loras) == 1
  636. def test_schedule_swapped_cannot_swap_in():
  637. scheduler = initialize_scheduler()
  638. curr_loras = None
  639. blocks_to_swap_out: List[Tuple[int, int]] = []
  640. for _ in range(2):
  641. _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
  642. scheduler._allocate_and_set_running(seq_group)
  643. append_new_token_seq_group(60, seq_group, 1)
  644. scheduler._swap_out(seq_group, blocks_to_swap_out)
  645. scheduler._add_seq_group_to_swapped(seq_group)
  646. # The last request should be swapped out.
  647. scheduler.block_manager.can_swap_in = MagicMock()
  648. scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
  649. # Since we cannot swap in, none of the requests are swapped in.
  650. budget = create_token_budget()
  651. output = scheduler._schedule_swapped(budget, curr_loras)
  652. remaining_swapped = scheduler.swapped
  653. assert len(remaining_swapped) == 2
  654. assert budget.num_batched_tokens == 0
  655. assert budget.num_curr_seqs == 0
  656. assert len(output.decode_seq_groups) == 0
  657. assert len(output.prefill_seq_groups) == 0
  658. def test_infeasible_swap():
  659. scheduler = initialize_scheduler()
  660. curr_loras = None
  661. blocks_to_swap_out: List[Tuple[int, int]] = []
  662. for _ in range(2):
  663. _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
  664. scheduler._allocate_and_set_running(seq_group)
  665. append_new_token_seq_group(60, seq_group, 1)
  666. scheduler._swap_out(seq_group, blocks_to_swap_out)
  667. scheduler._add_seq_group_to_swapped(seq_group)
  668. # The last request should be swapped out.
  669. scheduler.block_manager.can_swap_in = MagicMock()
  670. scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
  671. # Since we cannot swap in, none of the requests are swapped in.
  672. budget = create_token_budget()
  673. output = scheduler._schedule_swapped(budget, curr_loras)
  674. remaining_swapped = scheduler.swapped
  675. assert len(remaining_swapped) == 0
  676. assert len(output.infeasible_seq_groups) == 2
  677. assert budget.num_batched_tokens == 0
  678. assert budget.num_curr_seqs == 0
  679. assert len(output.decode_seq_groups) == 0
  680. assert len(output.prefill_seq_groups) == 0
  681. def test_schedule_swapped_blocks_to_copy():
  682. scheduler = initialize_scheduler()
  683. curr_loras = None
  684. _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
  685. scheduler._allocate_and_set_running(seq_group)
  686. append_new_token_seq_group(60, seq_group, 1)
  687. blocks_to_swap_out: List[Tuple[int, int]] = []
  688. scheduler._swap_out(seq_group, blocks_to_swap_out)
  689. scheduler._add_seq_group_to_swapped(seq_group)
  690. # The last request should be swapped out.
  691. scheduler.block_manager.append_slots = MagicMock()
  692. scheduler.block_manager.append_slots.return_value = [(2, 3)]
  693. budget = create_token_budget()
  694. output = scheduler._schedule_swapped(budget, curr_loras)
  695. remaining_swapped = scheduler.swapped
  696. assert len(remaining_swapped) == 0
  697. assert len(output.decode_seq_groups) == 1
  698. assert len(output.prefill_seq_groups) == 0
  699. assert output.blocks_to_copy == [(2, 3)]
  700. def test_scheduling_budget():
  701. TOKEN_BUDGET = 4
  702. MAX_SEQS = 4
  703. budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS)
  704. assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1)
  705. assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4)
  706. assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5)
  707. assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1)
  708. assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5)
  709. assert budget.remaining_token_budget() == TOKEN_BUDGET
  710. # Verify add/subtract num batched tokens.
  711. _, seq_group = create_dummy_prompt("1", 3)
  712. budget.add_num_batched_tokens(seq_group.request_id, 2)
  713. assert budget.remaining_token_budget() == 2
  714. assert budget.num_batched_tokens == 2
  715. assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
  716. assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
  717. # Verify adding another seq group is no-op.
  718. budget.add_num_batched_tokens(seq_group.request_id, 2)
  719. assert budget.remaining_token_budget() == 2
  720. assert budget.num_batched_tokens == 2
  721. budget.subtract_num_batched_tokens(seq_group.request_id, 2)
  722. assert budget.remaining_token_budget() == 4
  723. assert budget.num_batched_tokens == 0
  724. budget.subtract_num_batched_tokens(seq_group.request_id, 2)
  725. assert budget.remaining_token_budget() == 4
  726. assert budget.num_batched_tokens == 0
  727. # Verify add/subtract max seqs.
  728. _, seq_group = create_dummy_prompt("1", 3)
  729. budget.add_num_seqs(seq_group.request_id, 2)
  730. assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
  731. assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
  732. assert budget.num_curr_seqs == 2
  733. # Verify adding another seq group is no-op.
  734. budget.add_num_seqs(seq_group.request_id, 2)
  735. assert budget.num_curr_seqs == 2
  736. budget.subtract_num_seqs(seq_group.request_id, 2)
  737. assert budget.num_curr_seqs == 0
  738. budget.subtract_num_seqs(seq_group.request_id, 2)
  739. assert budget.num_curr_seqs == 0