  1. """The tests in this file verify end-to-end speculative decoding correctness.
  2. This docstring details important information on the testing methodology.
  3. Most of the tests rely on "greedy equality", where we expect the output of
  4. speculative decoding on a sequence to exactly match the output of normal non-
  5. speculative decoding.
  6. Since speculative decoding with rejection sampling guarantees that the output
  7. distribution matches the target model's output distribution (up to hardware
  8. numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
  9. equality. This gives us good coverage of temp=0.
  10. At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
  11. highest probability in the target distribution are accepted. Therefore, we can
  12. expect greedy equality for the TypicalAcceptanceSampler at temp=0.
  13. For temp>0, we rely on unit tests on the rejection sampler to verify that the
  14. output distribution is the same with spec decode vs. no spec decode (this would
  15. be prohibitively expensive to run with a real model). Similarly, for the
  16. TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
  17. test cases.
  18. NOTE: Speculative decoding's distribution equality requires that the measured
  19. distributions of the target model and proposal model be deterministic given the
  20. same input. aphrodite largely guarantees this.
  21. @cadedaniel has seen cases where the output probabilities of a draft/target
  22. model change slightly with certain batch sizes or prompts, even with Torch
  23. determinism flags set. It is unclear if this is a bug in aphrodite, due to non-
  24. determinism in on-device batched operations, a bug in aphrodite's spec decode
  25. implementation, or the "hardware numerics" limitations. Either way, rejection
  26. sampling ensures the output distribution matches the target model, but it breaks
  27. greedy-equality tests for those batch sizes/prompts.
  28. """

from itertools import cycle

import pytest
from transformers import AutoTokenizer

from aphrodite import SamplingParams

from .conftest import (get_output_from_llm_generator,
                       run_greedy_equality_correctness_test)
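

# For intuition only: a minimal, self-contained sketch of the modified
# rejection sampling rule from https://arxiv.org/pdf/2302.01318.pdf that the
# module docstring relies on. This is NOT aphrodite's implementation (the real
# logic lives in the rejection sampler and its unit tests); the name,
# signature, and use of stdlib `random` below are illustrative assumptions.
def _toy_rejection_sample(p_target, p_draft, draft_token, rng):
    """Accept `draft_token` with probability min(1, p_target / p_draft);
    otherwise resample from the normalized positive part of
    (p_target - p_draft). This rule is what makes the output distribution
    match the target model, justifying the greedy-equality tests below.

    `p_target`/`p_draft` are lists of token probabilities; `rng` is a
    `random.Random` instance.
    """
    accept_prob = min(1.0, p_target[draft_token] / p_draft[draft_token])
    if rng.random() < accept_prob:
        return draft_token
    # Rejected: sample from the clamped residual distribution instead. Note
    # this branch is unreachable when p_target == p_draft (accept_prob is 1),
    # so the residual always has nonzero mass here.
    residual = [max(t - d, 0.0) for t, d in zip(p_target, p_draft)]
    return rng.choices(range(len(residual)), weights=residual, k=1)[0]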


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        # Note this is repeated in the test body, to initialize a tokenizer.
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
        },
        {
            # Verify the detokenizer assertions in the test work when spec
            # decode is disabled.
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
                                             batch_size: int):
    """Run generation with speculative decoding on a batch. Verify the engine
    generates the correct number of tokens (via ignore_eos=True), and that the
    detokenization matches HF transformers.
    """
    output_len = 32
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Repeat the prompts cyclically until the batch is full.
    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Expect a generation for each prompt in the batch.
    assert len(batch_token_ids) == len(prompts)

    # Expect each generation to have the expected number of tokens (note that
    # ignore_eos is True).
    assert [len(token_ids)
            for token_ids in batch_token_ids] == ([output_len] * batch_size)

    # Expect the detokenized string to match HF transformers' output.
    tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
    for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
        expected_tokens = tok.decode(actual_token_ids)
        print(f"{actual_token_ids=}")
        assert actual_tokens.strip() == expected_tokens.strip()


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Use the async LLM engine.
        "use_async": True,
    }])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
                                           baseline_llm_generator,
                                           batch_size: int):
    """Verify spec decode works well with the async LLM engine.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=32,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model": "JackFram/llama-68m",
        },
        {
            "model": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use long output len for the small model test.
        1536,
    ])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality on a tiny model with batch size of one.

    Since this test is cheaper than other e2e correctness tests, we generate
    with a higher output_len.

    When the draft model is the same as the target model, we further check
    that all speculative tokens are accepted: with identical draft and target
    distributions, the acceptance probability min(1, p_target / p_draft) is
    always 1.
    """
    ensure_all_accepted = test_llm_generator.same_draft_target_model
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
        ensure_all_accepted=ensure_all_accepted)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model": "JackFram/llama-68m",
        },
        {
            "model": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality on a tiny model and large batch size.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model": "JackFram/llama-68m",
        },
        {
            "model": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("max_output_len", [
    256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        max_output_len: int):
    """Verify greedy equality on a tiny model, with a large batch size, and
    when sampling respects the EOS token.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len=False)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A "real" model (not tiny).
        "model": "meta-llama/Llama-2-7b-chat-hf",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use decently long output len for a high quality test.
        256,
    ])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality on a "real" model and batch size of 1. This is
    separate from large BS tests to make identifying the source of bugs
    easier.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A "real" model (not tiny).
        "model": "meta-llama/Llama-2-7b-chat-hf",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        64,
    ])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality with a "real" model on a nontrivial batch size.
    This is the closest test to a real production workload.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 blocks for the small prompt, 256 // 8 blocks for the generated
        # tokens.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
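# Spelled out, the block budget above is num_gpu_blocks_override =
# 2 + 256 // 8 = 34 blocks of size 8 (272 token slots), and max_model_len =
# 34 * 8 = 272. That is only enough KV-cache space for roughly one sequence
# at full length, so with batch_size=4 the scheduler must preempt sequences
# mid-generation, which is what this test exercises.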
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-160m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_with_preemption(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # As of this writing, aphrodite only compiles with these 3 block
        # sizes by default.
        {
            "block_size": 8,
        },
        {
            "block_size": 16,
        },
        {
            "block_size": 32,
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_different_block_size(baseline_llm_generator,
                                          test_llm_generator, batch_size: int,
                                          output_len: int):
    """Verify greedy equality over different block sizes.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,

            # Artificially limit the draft model max model len; this forces
            # aphrodite to skip speculation once the sequences grow beyond
            # 32 - k tokens.
            "speculative_max_model_len": 32,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # This must be a good bit larger than speculative_max_model_len so
        # that we can test the case where all seqs are skipped, but still
        # small to ensure a fast test.
        64,
    ])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
                          batch_size: int, output_len: int):
    """Verify greedy equality when some (or all) sequences skip speculation.
    We do this by setting the max model len of the draft model to an
    artificially low value, such that when the sequences grow beyond it, they
    are skipped in speculative decoding.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
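

# For intuition only: a sketch of the skip condition described above. The
# function name and exact comparison are illustrative assumptions, not
# aphrodite's real check. A sequence can only be speculated while its current
# length plus the k proposed tokens still fits within the draft model's max
# model len, i.e. speculation stops once seq_len grows beyond 32 - k in this
# test.
def _toy_can_speculate(seq_len: int, k: int,
                       speculative_max_model_len: int = 32) -> bool:
    # True while the sequence plus k draft tokens fits in the draft model.
    return seq_len + k <= speculative_max_model_len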


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "speculative_disable_by_batch_size": 2,
    },
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
                             batch_size: int, output_len: int):
    """Verify greedy equality when all sequences disable speculation: since
    batch_size=8 exceeds speculative_disable_by_batch_size=2, speculation is
    turned off for the whole batch.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": k,
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                output_len: int):
    """Verify that speculative decoding produces output exactly equal to
    non-speculative decoding for many different values of k.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": k,
            "spec_decoding_acceptance_method": "typical_acceptance_sampler"
        }
        # Try a range of common k.
        for k in [1, 2, 3]
    ])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator,
                                     test_llm_generator, batch_size: int,
                                     output_len: int):
    """Verify that speculative decoding produces output exactly equal to
    non-speculative decoding when TypicalAcceptanceSampler is used as the
    draft token acceptance method.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
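

# For intuition only: a sketch of the "typical acceptance" criterion this
# test exercises, following the rule described in the Medusa paper
# (https://arxiv.org/abs/2401.10774). The function name, parameter names, and
# default values here are illustrative assumptions, not aphrodite's exact
# implementation or defaults.
def _toy_typical_accept(p_target, draft_token,
                        epsilon=0.09, delta=0.3) -> bool:
    import math

    # Accept if the target probability of the draft token clears an
    # entropy-adaptive threshold: min(epsilon, delta * exp(-H(p_target))).
    entropy = -sum(p * math.log(p) for p in p_target if p > 0.0)
    threshold = min(epsilon, delta * math.exp(-entropy))

    # At temp=0 the target distribution is effectively one-hot: entropy is 0
    # and only the argmax token has nonzero probability, so only it can pass
    # the threshold. This is why greedy equality holds for this sampler too.
    return p_target[draft_token] >= threshold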