test_multi_step_worker.py

import random
from typing import Dict, List
from unittest.mock import MagicMock

import pytest
import torch

from aphrodite.common.sequence import (ExecuteModelRequest, HiddenStates,
                                        Logprob, get_all_seq_ids)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.utils import set_random_seed
from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.top1_proposer import Top1Proposer
from aphrodite.worker.worker import Worker

from .utils import (assert_logprobs_dict_allclose, create_batch,
                    create_seq_group_metadata_from_prompts, create_worker,
                    patch_execute_model_with_seeds, zero_kv_cache)
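
# Tests for the speculative-decoding MultiStepWorker (which runs its draft
# model for several steps in one call) and for Top1Proposer (which packages
# draft outputs into speculative proposals). Most tests compare the multi-step
# path against a plain Worker driven one step at a time with matching seeds
# and zeroed KV caches.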


@pytest.mark.parametrize("num_steps", list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
    """Test that the multi-step worker checks for sufficient space in the KV
    cache. It should raise if it cannot run all the steps.
    """
    block_size = 16
    num_gpu_blocks = 2048 // block_size

    prompts = [
        list(range(block_size * 3)),
        list(range(block_size * 2)),
    ]

    prev_output_tokens = [
        list(range(block_size * 1)),
        list(range(block_size * 2)),
    ]

    final_prompt_lens = [
        len(prompt + output) + num_steps
        for prompt, output in zip(prompts, prev_output_tokens)
    ]

    inputs = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens,
        continuations=prev_output_tokens,
    )

    assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
    worker = MagicMock()
    worker.model_runner.block_size = block_size

    for seq_group_metadata in inputs:
        original_block_tables = seq_group_metadata.block_tables

        # No exception.
        assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = {
            seq_id: []
            for seq_id, physical_blocks in original_block_tables.items()
        }

        # Expect exception.
        with pytest.raises(
            ValueError, match="times but found insufficient KV space for"
        ):
            assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = original_block_tables
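

# A minimal sketch of the slot arithmetic the check above relies on. This is
# an illustration inferred from the test's setup (an assumption, not a copy of
# MultiStepWorker's implementation): a sequence already holding seq_len tokens
# needs roughly seq_len + num_steps KV slots to run num_steps more decode
# steps, so a sequence whose block table has been emptied always fails.
def _example_enough_kv_space(seq_len: int, num_steps: int, block_size: int,
                             num_allocated_blocks: int) -> bool:
    # Each allocated block provides block_size KV slots.
    required_slots = seq_len + num_steps
    allocated_slots = num_allocated_blocks * block_size
    return allocated_slots >= required_slots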


@torch.inference_mode()
def test_same_output_for_single_step():
    """Verify the multi-step worker produces the same output as the normal
    worker for num_steps=1.
    """
    seed = 100
    model_name = "JackFram/llama-68m"

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    # multi_step_worker.model_runner = worker.model_runner
    # multi_step_worker.cache_engine = worker.cache_engine

    num_steps = 1

    prompts = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
    ]

    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

    multi_step_seq_group = create_seq_group_metadata_from_prompts(
        prompts, num_gpu_blocks, block_size, final_prompt_lens=final_prompt_lens
    )

    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    actual_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=multi_step_seq_group
        ),
        sample_len=num_steps,
        seq_ids_with_bonus_token_in_last_step=set(),
    )
    assert len(actual_output) == num_steps
    actual_output = actual_output[0]

    single_step_seq_group = create_seq_group_metadata_from_prompts(
        prompts, num_gpu_blocks, block_size, final_prompt_lens=final_prompt_lens
    )

    zero_kv_cache(worker.cache_engine)
    set_random_seed(seed)
    expected_output = worker.execute_model(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=single_step_seq_group
        )
    )[0]

    actual_token_ids = [
        output.samples[0].output_token for output in actual_output
    ]
    actual_logprobs = [output.samples[0].logprobs for output in actual_output]

    expected_token_ids = [
        output.samples[0].output_token for output in expected_output
    ]
    expected_logprobs = [
        output.samples[0].logprobs for output in expected_output
    ]

    assert actual_token_ids == expected_token_ids

    print(f"{actual_logprobs=}")
    print(f"{expected_logprobs=}")
    assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
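
# Note on the comparison above: MultiStepWorker.sampler_output returns a list
# with one SamplerOutput per draft step (hence the len(actual_output) ==
# num_steps check) plus a flag the test ignores, so for num_steps=1 its single
# entry should match a plain Worker.execute_model call made with the same seed
# and a zeroed KV cache.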


@torch.inference_mode()
def test_same_output_for_multi_step():
    """Verify the multi-step worker produces the same output as the normal
    worker when num_steps > 1. This test runs the multi-step worker once, then
    runs the normal worker num_steps times, and compares the outputs.
    """
    seed = 100
    model_name = "JackFram/llama-68m"

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    # Make sure we go over the block boundary.
    num_steps = block_size + 1

    random.seed(seed)
    prompts = [
        [random.randint(0, 1000) for _ in range(random.randint(10, 20))]
        for _ in range(10)
    ]

    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds
    )
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    continuations = [[1] for _ in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens,
    )

    # Run multi-step.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list
        ),
        sample_len=num_steps,
        seq_ids_with_bonus_token_in_last_step=set(),
    )

    # Run single-step repeatedly.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    continuations = [[1] for _ in prompts]
    set_random_seed(seed)

    for _ in multi_step_output:
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens,
        )

        single_step_output.extend(
            worker.execute_model(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list
                )
            )
        )

        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Get token ids and logprobs for comparison.
    multi_step_output_logprobs: List[List[Dict[int, Logprob]]] = [
        [] for _ in prompts
    ]
    single_step_output_logprobs: List[List[Dict[int, Logprob]]] = [
        [] for _ in prompts
    ]

    multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
    single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
    for i, _ in enumerate(prompts):
        for multi_step, single_step in zip(
            multi_step_output, single_step_output
        ):
            multi_step_output_token_ids[i].append(
                multi_step[i].samples[0].output_token
            )
            single_step_output_token_ids[i].append(
                single_step[i].samples[0].output_token
            )
            multi_step_output_logprobs[i].append(
                multi_step[i].samples[0].logprobs
            )
            single_step_output_logprobs[i].append(
                single_step[i].samples[0].logprobs
            )

    # Print per-sequence token ids.
    for i, (multi_step_tokens, single_step_tokens) in enumerate(
        zip(multi_step_output_token_ids, single_step_output_token_ids)
    ):
        print(f"{i=} {multi_step_tokens=}")
        print(f"{i=} {single_step_tokens=}")
        print(f"{i=} equal {multi_step_tokens == single_step_tokens}")

    # Assert token ids are equal.
    for multi_step_tokens, single_step_tokens in zip(
        multi_step_output_token_ids, single_step_output_token_ids
    ):
        assert multi_step_tokens == single_step_tokens

    # Assert logprobs are equal.
    for multi_step_logprobs, single_step_logprobs in zip(
        multi_step_output_logprobs, single_step_output_logprobs
    ):
        assert_logprobs_dict_allclose(multi_step_logprobs, single_step_logprobs)
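
# The equivalence technique above: the multi-step path emits all num_steps
# draft tokens from a single sampler_output call, while the reference path
# re-creates the sequence group metadata each iteration and feeds every newly
# sampled token back in through `continuations`. Identical per-step seeds
# (patch_execute_model_with_seeds) and zeroed KV caches make the two runs
# comparable token-for-token and logprob-for-logprob.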


@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
    """
    Verify that the MultiStepWorker handles bonus tokens correctly: if a
    sequence has a bonus token, the MultiStepWorker expands the batch by
    adding new sequences corresponding to the sequences with bonus tokens,
    and the expanded batch is then used for predicting the next tokens.
    """
    seed = 100
    model_name = "JackFram/llama-68m"
    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds
    )
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    # Create the test continuations.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens,
    )

    # Run single-step twice to generate 2 tokens. This
    # will simulate the bonus token case with the second token
    # being the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens,
        )
        single_step_output.extend(
            worker.execute_model(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list
                )
            )
        )
        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Create continuations for the MultiStepWorker. The continuations have
    # 2 tokens in order to simulate the bonus token case.
    multi_step_continuations = []
    for continuation in continuations:
        multi_step_continuations.append(continuation[:2])
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=multi_step_continuations,
        final_prompt_lens=final_prompt_lens,
    )

    # Run multi-step and verify that the third token prediction is accurate
    # for all sequences.
    zero_kv_cache(multi_step_worker.cache_engine)
    all_seq_ids = {i for i in range(batch_size)}
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list
        ),
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=all_seq_ids,
    )
    for index, output in enumerate(multi_step_output[-1].outputs):
        assert continuations[index][-1] == output.samples[0].output_token
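
# What the final assertion checks: the reference Worker has already produced
# three tokens per sequence (one seeded random continuation token plus two
# sampled tokens, the second playing the role of the bonus token). The
# MultiStepWorker is then given only the first two and told, via
# seq_ids_with_bonus_token_in_last_step, that every sequence carries a bonus
# token; with correct batch expansion its single-step prediction must
# reproduce the reference's third token.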


@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
    """
    Tests the MultiStepWorker's ability to handle batch expansion with bonus
    tokens in a negative case scenario. This test provides the MultiStepWorker
    with a batch containing sequences with bonus tokens but specifies the
    sequence IDs with bonus tokens incorrectly. The test verifies that the
    MultiStepWorker generates correct tokens for the sequences where the
    sequence ID is specified correctly and incorrect tokens for those where
    the sequence ID is specified incorrectly.
    """
    seed = 100
    model_name = "JackFram/llama-68m"
    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds
    )
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    # Create the test continuations.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens,
    )

    # Run single-step twice to generate 2 tokens. This
    # will simulate the bonus token case with the second token
    # being the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens,
        )
        single_step_output.extend(
            worker.execute_model(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list
                )
            )
        )
        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Create continuations for the MultiStepWorker. The continuations have
    # 2 tokens in order to simulate the bonus token case.
    multi_step_continuations = []
    for continuation in continuations:
        multi_step_continuations.append(continuation[:2])
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=multi_step_continuations,
        final_prompt_lens=final_prompt_lens,
    )

    # Run multi-step. In this run INCORRECTLY specify that only the
    # odd-numbered sequences have bonus tokens. Verify that with this setting
    # the third-token prediction is accurate only for the odd-numbered
    # sequences. Also verify that the prediction might be wrong for some of
    # the even-numbered sequences.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list
        ),
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=odd_seq_ids,
    )
    num_mismatch = 0
    for index, output in enumerate(multi_step_output[-1].outputs):
        if (index % 2) != 0:
            assert continuations[index][-1] == output.samples[0].output_token
        elif continuations[index][-1] != output.samples[0].output_token:
            num_mismatch += 1
    # The prediction is accurate for some of the sequences even without proper
    # handling of the bonus tokens. Hence verify that the number of sequences
    # for which there is a mismatch is > 0.
    assert num_mismatch > 0
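
# Negative-case counterpart of the test above: only the odd-numbered sequence
# IDs are flagged as having bonus tokens, so those sequences must still
# reproduce the reference's third token, while at least one even-numbered
# sequence (whose bonus token is not accounted for by the expansion) is
# expected to diverge.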


@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
    """Verify Top1Proposer correctly handles the case where all sequences
    can speculate.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = "cuda:0"

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=2048,
    )
    draft_worker.sampler_output.return_value = (
        [
            SamplerOutput(
                outputs=[],
                sampled_token_probs=torch.rand(
                    batch_size, vocab_size, device=device, dtype=torch.float32
                ),
                logprobs=torch.rand(
                    batch_size, vocab_size, device=device, dtype=torch.float32
                ),
                sampled_token_ids=torch.randint(
                    low=0,
                    high=vocab_size,
                    size=(batch_size,),
                    device=device,
                    dtype=torch.long,
                ),
            )
            for _ in range(k)
        ],
        True,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k,
        ),
        seq_ids_with_bonus_token_in_last_step=set(),
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
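
# Shape contract exercised above: the proposer stacks k per-step draft outputs
# into proposal_token_ids of shape [batch_size, k], proposal_probs of shape
# [batch_size, k, vocab_size] (the test only pins the two leading dimensions),
# and per-sequence proposal_lens of shape [batch_size], all equal to k here
# since every sequence fits under max_proposal_len.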


@torch.inference_mode()
def test_draft_proposals_no_speculations():
    """Verify Top1Proposer correctly handles the case where no sequences
    can speculate.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = "cuda:0"
    prompt_len = 10

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=prompt_len + k - 1,
    )

    seq_group_metadata_list, _, _ = create_batch(
        batch_size, k, prompt_len=prompt_len
    )

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k,
        ),
        seq_ids_with_bonus_token_in_last_step=set(),
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
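
# Here max_proposal_len is prompt_len + k - 1, one slot short of what a full
# k-token speculation would need, so no sequence is allowed to speculate: the
# proposer still returns tensors of the same [batch_size, k] shape, but all
# proposal_lens are zero. Note the draft worker's sampler_output return value
# is never mocked in this test, which implies the proposer is not expected to
# consult it when nothing can speculate.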


@torch.inference_mode()
def test_draft_proposals_mixed_k():
    """Verify Top1Proposer correctly handles the case where some sequences can
    speculate and others cannot.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = "cuda:0"

    small_prompt_len = 5
    long_prompt_len = 10
    prev_output_token_len = 20

    expected_num_proposal_seqs = 6
    expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs

    prompt_len = (
        [small_prompt_len for _ in range(expected_num_proposal_seqs - 1)]
        + [long_prompt_len for _ in range(expected_num_no_proposal_seqs)]
        + [small_prompt_len]
    )

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
    )

    draft_worker.sampler_output.return_value = (
        [
            SamplerOutput(
                outputs=[],
                sampled_token_probs=torch.rand(
                    expected_num_proposal_seqs,
                    vocab_size,
                    device=device,
                    dtype=torch.float32,
                ),
                logprobs=torch.rand(
                    expected_num_proposal_seqs,
                    vocab_size,
                    device=device,
                    dtype=torch.float32,
                ),
                sampled_token_ids=torch.randint(
                    low=0,
                    high=vocab_size,
                    size=(expected_num_proposal_seqs,),
                    device=device,
                    dtype=torch.long,
                ),
            )
            for _ in range(k)
        ],
        True,
    )

    seq_group_metadata_list, _, _ = create_batch(
        batch_size,
        k,
        prompt_len=prompt_len,
        prev_output_token_len=prev_output_token_len,
    )

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k,
        ),
        seq_ids_with_bonus_token_in_last_step=set(),
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [
        k for _ in range(expected_num_proposal_seqs - 1)
    ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
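
# The prompt_len layout above deliberately places one speculatable (short)
# sequence at the very end of the batch, after a run of non-speculatable
# (long) ones. The mocked draft outputs therefore only cover the 6
# speculatable sequences, and the final proposal_lens assertion checks that
# the proposer scatters those results back to their original, non-contiguous
# batch positions while filling the skipped sequences with length 0.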


@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
    """Verify that the draft model runner triggers advance step
    when applicable.
    """
    seed = 100
    model_name = "JackFram/llama-68m"

    k = 5
    batch_size = 32
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    # Mock "_gpu_advance_step" to raise an exception when called.
    exception_secret = "artificial stop"
    worker.model_runner._gpu_advance_step = MagicMock()
    worker.model_runner._gpu_advance_step.side_effect = ValueError(
        exception_secret
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    # Fallback (should not call) when num_steps=1.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=1,
    )
    worker.execute_model(execute_model_req=execute_model_req)

    # Expect exception if _gpu_advance_step is called.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=k,
    )
    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    call_args_list = worker.model_runner._gpu_advance_step.call_args_list
    assert len(call_args_list) == 1
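
# The mocking pattern above replaces TP1DraftModelRunner._gpu_advance_step
# with a MagicMock whose side_effect raises a recognizable ValueError. The
# num_steps=1 request must complete without touching it, while the num_steps=k
# request is expected to hit the GPU advance-step path exactly once, which the
# call_args_list length confirms.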


@torch.inference_mode()
def test_expand_execute_model_request_sync_with_expand_hidden_states():
    """
    In this test we verify that the logic for expanding the
    seq_group_metadata_list remains in sync with the expansion logic of
    the HiddenStates in _expand_execute_model_request.
    """
    k = 5
    batch_size = 16
    seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    execute_model_request = ExecuteModelRequest(
        seq_group_metadata_list,
        previous_hidden_states=HiddenStates(
            torch.arange(batch_size),
            seq_group_metadata_list,
            torch.arange(batch_size, 2 * batch_size),
        ),
    )

    (
        expanded_execute_model_request,
        orig_seq_group_ids,
    ) = MultiStepWorker._expand_execute_model_request(
        execute_model_request, seq_with_bonus_token_in_last_step
    )

    all_seq_ids = torch.tensor(
        get_all_seq_ids(expanded_execute_model_request.seq_group_metadata_list)
    )
    ref_expanded_hidden_states = all_seq_ids + batch_size
    ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
    assert (
        ref_expanded_hidden_states
        == expanded_execute_model_request.previous_hidden_states.hidden_states
    ).all().item()
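
# Reference construction for the assertion above: the HiddenStates object is
# seeded with two arange tensors, so the original sequence groups should keep
# values from arange(batch_size) while rows added by the expansion should pick
# up the shifted arange(batch_size, 2 * batch_size) values. Building the
# expectation from all_seq_ids and orig_seq_group_ids keeps the check in sync
# with however _expand_execute_model_request orders the expanded batch.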