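"""Tests for the Aphrodite scheduler with chunked prefill enabled."""
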
from typing import List
from unittest.mock import MagicMock

import pytest  # noqa

from aphrodite.common.config import CacheConfig, SchedulerConfig
from aphrodite.common.sequence import Logprob, SequenceGroup
from aphrodite.processing.interfaces import AllocStatus
from aphrodite.processing.scheduler import Scheduler

from .utils import create_dummy_prompt


def get_sequence_groups(scheduler_output):
    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def append_new_token(seq_group, token_id: int):
    for seq in seq_group.get_seqs():
        seq.append_token_id(token_id, {token_id: Logprob(token_id)})


def schedule_and_update_computed_tokens(scheduler):
    # Mimic the engine loop: after scheduling, record how many tokens of each
    # scheduled seq group were computed in this step.
    metas, out, _ = scheduler.schedule()
    for s, meta in zip(out.scheduled_seq_groups, metas):
        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
    return metas, out


def test_simple():
    """Verify basic scheduling works."""
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       num_seq_group,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Schedule seq groups prompts.
    num_tokens = block_size * num_seq_group
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_tokens
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group
    for s in running:
        append_new_token(s, 1)

    # Schedule seq groups generation.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_seq_group
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group


def test_chunk():
    """Verify prefills are chunked properly."""
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify it is chunked.
    assert seq_group_meta[1].token_chunk_size == 4
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
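    # Note: the budget of 64 is split as 60 (the entire first prompt) plus
    # 4 (the first chunk of the second prompt, capped by the leftover budget).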

    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # One chunked prefill, and one decoding.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # The first one is the remaining prefill chunk of the second request.
    # The scheduler guarantees that prefills are ordered before decodes.
    assert seq_group_meta[0].token_chunk_size == 56
    # The second one is the decode of the first request.
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 57


def test_complex():
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()

    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify it is chunked.
    assert seq_group_meta[1].token_chunk_size == 4
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Add 2 more requests.
    for i in range(2, 4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 3
    # The first one is the first chunk of the newly added 3rd request.
    assert seq_group_meta[0].token_chunk_size == 7
    # The second one is the remaining prefill chunk of the 2nd request.
    assert seq_group_meta[1].token_chunk_size == 56
    # The last one is the decode of the 1st request.
    assert seq_group_meta[2].token_chunk_size == 1
    # Two of them are in chunked prefill.
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
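    # Note: 64 = 56 (remaining prefill of the 2nd request) + 7 (first chunk of
    # the 3rd request, capped by the leftover budget) + 1 (decode of the 1st).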

    # The first 2 requests are now in the decoding phase.
    append_new_token(running[0], 1)
    assert not running[0].is_prefill()
    append_new_token(running[1], 1)
    assert not running[1].is_prefill()
    # The third request is still in the prefill phase.
    assert running[2].is_prefill()


def test_maximal_decoding():
    """Verify decoding requests are prioritized."""
    block_size = 4
    max_seqs = 2
    max_model_len = 8
    max_num_batched_tokens = 2
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=2)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()

    # The first prefill is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Create one more seq_group.
    _, seq_group = create_dummy_prompt("3", prompt_length=2)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()

    # The first request's decode + the second request's first prefill chunk
    # are scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)

    # Decoding + running prefill is prioritized.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # Only decoding is prioritized.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 0
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # After aborting the decoding request, the new prefill is scheduled next
    # in FCFS order.
    scheduler.abort_seq_group(running[0].request_id)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2


def test_prompt_limit():
    """Verify max_num_batched_tokens < max_model_len is possible."""
    block_size = 4
    max_seqs = 32
    max_model_len = 64
    max_num_batched_tokens = 32
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("1", prompt_length=48)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()

    # A prompt longer than max_num_batched_tokens should still be scheduled,
    # just chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 32
    assert running[0].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 32


def test_prompt_limit_exceed():
    block_size = 4
    max_seqs = 64
    max_model_len = 32
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("2", prompt_length=48)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()

    # The prompt (48 tokens) exceeds max_model_len (32), so the request is
    # ignored rather than scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.ignored_seq_groups) == 1
    assert out.ignored_seq_groups[0] == seq_group


def test_swap():
    """Verify swapping works with chunked prefill requests."""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked; only its first prefill chunk is scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)
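    # can_append_slots now reports False for request "1", so the scheduler
    # cannot keep it running and must swap it out on the next step.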

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []

    # Add 1 more task. Swap should be prioritized over the new prefill.
    _, seq_group = create_dummy_prompt("2", prompt_length=60)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # The swapped-out request is swapped back in and its remaining prefill
    # fills the budget (30 tokens).
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []


def test_running_prefill_prioritized_over_swap():
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked; only its first prefill chunk is scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []

    # Add 1 more task. Swap is not possible, so the new prefill runs instead.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
    _, seq_group2 = create_dummy_prompt("2", prompt_length=60)
    scheduler.add_seq_group(seq_group2)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # The first 30-token chunk of the new prefill is scheduled; nothing is
    # swapped in.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert out.scheduled_seq_groups[0].seq_group == seq_group2

    # Now although swap is possible, the running prefill is prioritized.
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # The remaining 30 prefill tokens of seq_group2 are scheduled; still no
    # swap-in.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Decoding is prioritized.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # A single decode token is scheduled; still no swap-in.
    assert out.num_batched_tokens == 1
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Since we abort the sequence group, we can finally swap.
    scheduler.abort_seq_group(seq_group2.request_id)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []


def test_chunked_prefill_preempt():
    """Verify preempt works with chunked prefill requests."""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1", prompt_length=60)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked; only its first prefill chunk is scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be preempted.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group1(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group1)

    # The running prefill is now preempted.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out == []
    assert out.blocks_to_swap_in == []

    # Make sure we can reschedule the preempted request.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens
    assert seq_group.get_num_uncomputed_tokens() == 30

    # We should be able to run prefill twice as it is chunked.
    def cannot_append_second_group2(seq_group, num_lookahead_slots):
        return True

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group2)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert not seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens


def test_chunked_prefill_max_seqs():
    block_size = 4
    max_seqs = 2
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       is_attention_free=False)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
                               is_attention_free=False)
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("1", prompt_length=65)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    # The first prefill is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 1

    # Add new requests.
    for i in range(4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=65)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Make sure only 2 requests are scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_batched_tokens == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    append_new_token(running[0], 1)

    # Although we have enough token budget, we can only schedule max_seqs.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 2
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_batched_tokens == 3
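    # Note: 3 = 2 remaining prefill tokens of the second request + 1 decode
    # token of the first; the other waiting requests cannot join the batch
    # because max_seqs == 2.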
    assert len(get_sequence_groups(out)) == max_seqs
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()


def test_prefix_caching():
    """Verify allocating full blocks when prefix caching is enabled."""
    block_size = 4
    max_seqs = 10
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size,
                               1.0,
                               1,
                               "auto",
                               enable_prefix_caching=True)
    cache_config.num_cpu_blocks = 0
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           block_size=block_size,
                                           prompt_length=50)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 50
    # Verify it is chunked. Note that although the leftover budget is
    # 64 - 50 = 14, with prefix caching only full blocks are allocated, so
    # only 4 * (14 // 4) = 12 tokens are scheduled for the second prompt.
    assert seq_group_meta[1].token_chunk_size == 12
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 62