# test_multi_step_worker.py

import random
from typing import Dict, List
from unittest.mock import MagicMock

import pytest
import torch

from aphrodite.common.sequence import (ExecuteModelRequest, Logprob,
                                       SamplerOutput)
from aphrodite.modeling.utils import set_random_seed
from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.top1_proposer import Top1Proposer
from aphrodite.task_handler.worker import Worker

from .utils import (assert_logprobs_dict_allclose, create_batch,
                    create_seq_group_metadata_from_prompts, create_worker,
                    patch_execute_model_with_seeds, zero_kv_cache)


@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
    """Test that the multi-step worker checks for sufficient space in the KV
    cache. It should raise if it cannot run all the steps.
    """
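    # Layout of this test: with block_size=16, the prompts and prior outputs
    # below occupy several full blocks each. Clearing a sequence's block
    # table (done further down) leaves it with no KV space at all, so the
    # check is expected to raise for every parametrized num_steps.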
    block_size = 16
    num_gpu_blocks = 2048 // block_size
    prompts = [
        list(range(block_size * 3)),
        list(range(block_size * 2)),
    ]
    prev_output_tokens = [
        list(range(block_size * 1)),
        list(range(block_size * 2)),
    ]
    final_prompt_lens = [
        len(prompt + output) + num_steps
        for prompt, output in zip(prompts, prev_output_tokens)
    ]

    inputs = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens,
        continuations=prev_output_tokens)

    assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
    worker = MagicMock()
    worker.model_runner.block_size = block_size

    for seq_group_metadata in inputs:
        original_block_tables = seq_group_metadata.block_tables

        # No exception.
        assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = {
            seq_id: []
            for seq_id, physical_blocks in original_block_tables.items()
        }

        # Expect exception.
        with pytest.raises(ValueError,
                           match='times but found insufficient KV space for'):
            assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = original_block_tables


@torch.inference_mode()
def test_same_output_for_single_step():
    """Verify the multi-step worker produces the same output as the normal
    worker for num_steps=1.
    """
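    # With sample_len=1, sampler_output() should return a list with exactly
    # one SamplerOutput (asserted below), which is compared against the first
    # output of a plain Worker.execute_model() call made with the same seed
    # and a zeroed KV cache.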
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    # multi_step_worker.model_runner = worker.model_runner
    # multi_step_worker.cache_engine = worker.cache_engine

    num_steps = 1

    prompts = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
    ]

    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

    multi_step_seq_group = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    actual_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=multi_step_seq_group),
        sample_len=num_steps,
        seq_ids_with_bonus_token_in_last_step=set())
    assert len(actual_output) == num_steps
    actual_output = actual_output[0]

    single_step_seq_group = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    zero_kv_cache(worker.cache_engine)
    set_random_seed(seed)
    expected_output = worker.execute_model(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=single_step_seq_group))[0]

    actual_token_ids = [
        output.samples[0].output_token for output in actual_output
    ]
    actual_logprobs = [output.samples[0].logprobs for output in actual_output]

    expected_token_ids = [
        output.samples[0].output_token for output in expected_output
    ]
    expected_logprobs = [
        output.samples[0].logprobs for output in expected_output
    ]

    assert actual_token_ids == expected_token_ids

    print(f'{actual_logprobs=}')
    print(f'{expected_logprobs=}')
    assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)


@torch.inference_mode()
def test_same_output_for_multi_step():
    """Verify the multi-step worker produces the same output as the normal
    worker when num_steps > 1. This test runs the multi-step worker once,
    runs the normal worker num_steps times, and compares the outputs.
    """
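    # The multi-step worker produces num_steps tokens from a single
    # sampler_output() call; the reference loop below instead feeds each
    # sampled token back to the normal worker via `continuations`. With the
    # same seeds, the two paths should agree token-for-token. num_steps is
    # block_size + 1 so the run crosses a KV-cache block boundary.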
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    # Make sure we go over the block boundary.
    num_steps = block_size + 1

    random.seed(seed)
    prompts = [[
        random.randint(0, 1000) for _ in range(random.randint(10, 20))
    ] for _ in range(10)]

    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    continuations = [[1] for _ in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens)

    # Run multi-step.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
        sample_len=num_steps,
        seq_ids_with_bonus_token_in_last_step=set())

    # Run single-step repeatedly.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    continuations = [[1] for _ in prompts]
    set_random_seed(seed)

    for _ in multi_step_output:

        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens)

        single_step_output.extend(
            worker.execute_model(execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list)))

        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Get token ids and logprobs for comparison.
    multi_step_output_logprobs: List[List[Dict[int, Logprob]]] = [
        [] for _ in prompts
    ]
    single_step_output_logprobs: List[List[Dict[int, Logprob]]] = [
        [] for _ in prompts
    ]
    multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
    single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
    for i, _ in enumerate(prompts):
        for multi_step, single_step in zip(multi_step_output,
                                           single_step_output):
            multi_step_output_token_ids[i].append(
                multi_step[i].samples[0].output_token)
            single_step_output_token_ids[i].append(
                single_step[i].samples[0].output_token)
            multi_step_output_logprobs[i].append(
                multi_step[i].samples[0].logprobs)
            single_step_output_logprobs[i].append(
                single_step[i].samples[0].logprobs)

    # Print per-sequence token ids.
    for i, (multi_step_tokens, single_step_tokens) in enumerate(
            zip(multi_step_output_token_ids, single_step_output_token_ids)):
        print(f'{i=} {multi_step_tokens=}')
        print(f'{i=} {single_step_tokens=}')
        print(f'{i=} equal {multi_step_tokens == single_step_tokens}')

    # Assert token ids are equal.
    for multi_step_tokens, single_step_tokens in zip(
            multi_step_output_token_ids, single_step_output_token_ids):
        assert multi_step_tokens == single_step_tokens

    # Assert logprobs are equal.
    for multi_step_logprobs, single_step_logprobs in zip(
            multi_step_output_logprobs, single_step_output_logprobs):
        assert_logprobs_dict_allclose(multi_step_logprobs,
                                      single_step_logprobs)


@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
    """
    Verify that the MultiStepWorker handles bonus tokens correctly: if a
    sequence has a bonus token, the worker expands the batch by adding new
    sequences for the sequences with bonus tokens, and the expanded batch is
    then used to predict the next tokens.
    """
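    # Test outline: a normal Worker first runs two single steps, so each
    # continuation ends up holding its starting token plus two generated
    # tokens. The MultiStepWorker is then given only the first two of those
    # tokens, the second acting as the bonus token, and is told via
    # seq_ids_with_bonus_token_in_last_step that every sequence has one. Its
    # next prediction should match the normal worker's last generated token.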
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
    # Create the test continuations.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens)

    # Run single-step twice to generate 2 tokens. This
    # will simulate the bonus token case with the second token
    # being the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens)
        single_step_output.extend(
            worker.execute_model(execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list)))
        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Create continuations for the MultiStepWorker. The continuations have
    # 2 tokens in order to simulate the bonus token case.
    multi_step_continuations = []
    for continuation in continuations:
        multi_step_continuations.append(continuation[:2])
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=multi_step_continuations,
        final_prompt_lens=final_prompt_lens)

    # Run multi-step and verify that the third token prediction is accurate
    # for all sequences.
    zero_kv_cache(multi_step_worker.cache_engine)
    all_seq_ids = {i for i in range(batch_size)}
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=all_seq_ids)
    for index, output in enumerate(multi_step_output[-1].outputs):
        assert (continuations[index][-1] == output.samples[0].output_token)


@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
    """
    Negative test for batch expansion with bonus tokens. The MultiStepWorker
    receives a batch in which every sequence has a bonus token, but the
    sequence IDs with bonus tokens are specified incorrectly. The test
    verifies that the worker generates correct tokens for the sequences whose
    IDs are specified correctly and may generate incorrect tokens for the
    rest.
    """
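    # Same setup as the positive test above, except that only the
    # odd-numbered sequence IDs are reported in
    # seq_ids_with_bonus_token_in_last_step. Odd sequences must still match
    # the reference output exactly; among the even sequences at least one
    # mismatch is expected, since their bonus token goes unaccounted for.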
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
    # Create the test continuations.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens)

    # Run single-step twice to generate 2 tokens. This
    # will simulate the bonus token case with the second token
    # being the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens)
        single_step_output.extend(
            worker.execute_model(execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list)))
        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Create continuations for the MultiStepWorker. The continuations have
    # 2 tokens in order to simulate the bonus token case.
    multi_step_continuations = []
    for continuation in continuations:
        multi_step_continuations.append(continuation[:2])
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=multi_step_continuations,
        final_prompt_lens=final_prompt_lens)
    # Run multi-step. In this run, INCORRECTLY specify that only the
    # odd-numbered sequences have bonus tokens. Verify that with this setting
    # the third token prediction is accurate only for the odd-numbered
    # sequences. Also verify that the prediction might be wrong for some of
    # the even-numbered sequences.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
    num_mismatch = 0
    for index, output in enumerate(multi_step_output[-1].outputs):
        if (index % 2) != 0:
            assert (continuations[index][-1] == output.samples[0].output_token)
        elif (continuations[index][-1] != output.samples[0].output_token):
            num_mismatch += 1
    # The prediction is accurate for some of the sequences even without proper
    # handling of the bonus tokens. Hence verify that the number of sequences
    # for which there is a mismatch is > 0.
    assert (num_mismatch > 0)


@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
    """Verify Top1Proposer correctly handles the case where all sequences
    can speculate.
    """
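    # The draft worker is a MagicMock whose sampler_output returns k
    # SamplerOutputs covering the whole batch, so the proposer is expected to
    # emit a length-k proposal (shape [batch_size, k]) for every sequence.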
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=2048,
    )
    draft_worker.sampler_output.return_value = [
        SamplerOutput(
            outputs=[],
            sampled_token_probs=torch.rand(batch_size,
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            logprobs=torch.rand(batch_size,
                                vocab_size,
                                device=device,
                                dtype=torch.float32),
            sampled_token_ids=torch.randint(low=0,
                                            high=vocab_size,
                                            size=(batch_size, ),
                                            device=device,
                                            dtype=torch.long),
        ) for _ in range(k)
    ], True

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k),
        seq_ids_with_bonus_token_in_last_step=set())

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]


@torch.inference_mode()
def test_draft_proposals_no_speculations():
    """Verify Top1Proposer correctly handles the case where no sequences
    can speculate.
    """
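    # max_proposal_len is prompt_len + k - 1, one slot short of what a
    # k-token proposal would need, so every sequence is expected to end up
    # with a proposal length of 0.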
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'
    prompt_len = 10

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=prompt_len + k - 1,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prompt_len=prompt_len)

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k),
        seq_ids_with_bonus_token_in_last_step=set())

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]


@torch.inference_mode()
def test_draft_proposals_mixed_k():
    """Verify Top1Proposer correctly handles the case where some sequences
    can speculate and some cannot.
    """
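    # Prompt lengths are laid out as 5 short prompts, then 26 long prompts,
    # then 1 final short prompt. Only the short-prompt sequences fit within
    # max_proposal_len, so the expected proposal lengths are k for the first
    # 5 sequences, 0 for the 26 long ones, and k for the last one.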
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    small_prompt_len = 5
    long_prompt_len = 10
    prev_output_token_len = 20

    expected_num_proposal_seqs = 6
    expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs

    prompt_len = [
        small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
    ] + [long_prompt_len
         for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
    )

    draft_worker.sampler_output.return_value = [
        SamplerOutput(
            outputs=[],
            sampled_token_probs=torch.rand(expected_num_proposal_seqs,
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            logprobs=torch.rand(expected_num_proposal_seqs,
                                vocab_size,
                                device=device,
                                dtype=torch.float32),
            sampled_token_ids=torch.randint(
                low=0,
                high=vocab_size,
                size=(expected_num_proposal_seqs, ),
                device=device,
                dtype=torch.long),
        ) for _ in range(k)
    ], True

    seq_group_metadata_list, _, _ = create_batch(
        batch_size,
        k,
        prompt_len=prompt_len,
        prev_output_token_len=prev_output_token_len,
    )

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k),
        seq_ids_with_bonus_token_in_last_step=set())

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [
        k for _ in range(expected_num_proposal_seqs - 1)
    ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]


@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
    """Verify that the draft model runner triggers advance_step
    when applicable.
    """
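    # _gpu_advance_step is mocked to raise, so the single-step request
    # (num_steps=1) must complete without touching it, while the multi-step
    # request (num_steps=k) is expected to hit it exactly once and surface
    # the planted exception.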
    seed = 100
    model_name = 'JackFram/llama-68m'

    k = 5
    batch_size = 32
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    # Mock "_gpu_advance_step" to raise an exception when called.
    exception_secret = "artificial stop"
    worker.model_runner._gpu_advance_step = MagicMock()
    worker.model_runner._gpu_advance_step.side_effect = ValueError(
        exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    # Fallback (should not call) when num_steps=1.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=1)
    worker.execute_model(execute_model_req=execute_model_req)

    # Expect exception if _gpu_advance_step is called.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    call_args_list = worker.model_runner._gpu_advance_step.call_args_list
    assert len(call_args_list) == 1