"""The tests in this file verify end-to-end speculative decoding correctness.

This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.

At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.

For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptanceSampler, we rely on unit tests to validate temp>0 test cases.

NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. Aphrodite largely guarantees this.

@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in Aphrodite, due to
non-determinism in on-device batched operations, a bug in Aphrodite's spec
decode implementation, or the "hardware numerics" limitations. Either way,
rejection sampling ensures the output distribution matches the target model,
but it breaks greedy-equality tests for those batch sizes/prompts.
"""
from itertools import cycle

import pytest
from transformers import AutoTokenizer

from aphrodite import SamplingParams

from ...utils import fork_new_process_for_each_test
from .conftest import (get_output_from_llm_generator,
                       run_equality_correctness_test)
  36. @pytest.mark.parametrize(
  37. "common_llm_kwargs",
  38. [{
  39. # Use a small model for a fast test.
  40. # Note this is repeated in the test body; to initialize a tokenizer.
  41. "model": "JackFram/llama-68m",
  42. # Skip cuda graph recording for fast test.
  43. "enforce_eager": True,
  44. # Required for spec decode.
  45. "use_v2_block_manager": True,
  46. }])
  47. @pytest.mark.parametrize(
  48. "per_test_common_llm_kwargs",
  49. [
  50. {
  51. "speculative_model": "JackFram/llama-68m",
  52. "num_speculative_tokens": 5,
  53. },
  54. {
  55. # Verify the detokenizer assertions in the test work when spec
  56. # decode is disabled.
  57. },
  58. ])
  59. @pytest.mark.parametrize("test_llm_kwargs", [{}])
  60. @pytest.mark.parametrize("batch_size", [1, 32])
  61. @pytest.mark.parametrize("seed", [1])
  62. @fork_new_process_for_each_test
  63. def test_spec_decode_e2e_with_detokenization(test_llm_generator,
  64. batch_size: int):
  65. """Run generation with speculative decoding on a batch. Verify the engine
  66. generates the correct number of tokens (via ignore_eos=True), and that the
  67. detokenization matches HF transformers.
  68. """
  69. output_len = 32
  70. temperature = 0.0
  71. prompts = [
  72. "Hello, my name is",
  73. "The president of the United States is",
  74. "The capital of France is",
  75. "The future of AI is",
  76. ]
  77. prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
  78. sampling_params = SamplingParams(
  79. max_tokens=output_len,
  80. ignore_eos=True,
  81. temperature=temperature,
  82. )
  83. batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
  84. test_llm_generator, prompts, sampling_params)
  85. # Expect a generation for each prompt in the batch.
  86. assert len(batch_token_ids) == len(prompts)
  87. # Expect each generation to have expected number of tokens (note ignore_eos
  88. # is True).
  89. assert [len(token_ids)
  90. for token_ids in batch_token_ids] == ([output_len] * batch_size)
  91. # Expect detokenized string to match.
  92. tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
  93. for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
  94. expected_tokens = tok.decode(actual_token_ids)
  95. print(f"{actual_token_ids=}")
  96. assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model_name": "JackFram/llama-68m",
        },
        {
            "model_name": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # NOTE(review): the docstring claims a "higher output_len" is used,
        # but this is only 10 tokens — confirm whether a larger value
        # (e.g. 256, as in sibling tests) was intended.
        10,
    ])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
        aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality on a tiny model with batch size of one.

    Since this test is cheaper than other e2e correctness tests, we generate
    with a higher output_len.

    When the draft model is the same as the target model, we further check
    whether all speculative tokens are accepted.
    """
    # Only require full acceptance when the draft model equals the target.
    ensure_all_accepted = per_test_common_llm_kwargs.get(
        "model_name") == test_llm_kwargs.get("speculative_model")
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0,
                                  ensure_all_accepted=ensure_all_accepted)
  157. @pytest.mark.parametrize(
  158. "common_llm_kwargs",
  159. [{
  160. # Skip cuda graph recording for fast test.
  161. "enforce_eager": True,
  162. # Required for spec decode.
  163. "use_v2_block_manager": True,
  164. # Print spec metrics.
  165. "disable_log_stats": False,
  166. }])
  167. @pytest.mark.parametrize(
  168. "per_test_common_llm_kwargs",
  169. [
  170. # Try two different tiny base models.
  171. # Note that one is equal to the draft model, another isn't.
  172. {
  173. "model_name": "JackFram/llama-68m",
  174. },
  175. {
  176. "model_name": "JackFram/llama-160m",
  177. },
  178. ])
  179. @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
  180. @pytest.mark.parametrize("test_llm_kwargs", [
  181. {
  182. "speculative_model": "JackFram/llama-68m",
  183. "num_speculative_tokens": 5,
  184. },
  185. ])
  186. @pytest.mark.parametrize(
  187. "output_len",
  188. [
  189. # Use small output len for fast test.
  190. 256,
  191. ])
  192. @pytest.mark.parametrize("batch_size", [64])
  193. @pytest.mark.parametrize("seed", [1])
  194. @fork_new_process_for_each_test
  195. def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
  196. aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
  197. baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
  198. seed: int):
  199. """Verify greedy equality on a tiny model and large batch size.
  200. """
  201. run_equality_correctness_test(aphrodite_runner,
  202. common_llm_kwargs,
  203. per_test_common_llm_kwargs,
  204. baseline_llm_kwargs,
  205. test_llm_kwargs,
  206. batch_size,
  207. max_output_len=output_len,
  208. seed=seed,
  209. temperature=0.0)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        # Try two different tiny base models.
        # Note that one is equal to the draft model, another isn't.
        {
            "model_name": "JackFram/llama-68m",
        },
        {
            "model_name": "JackFram/llama-160m",
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("max_output_len", [
    256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
        aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
        max_output_len: int, seed: int):
    """Verify greedy equality on a tiny model, with a large batch size, and when
    sampling respects the EOS token.
    """
    # ignore_eos=False lets sequences terminate early at EOS, so generations
    # in the batch may end up with different lengths (hence "diff_output_len").
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len,
                                  seed=seed,
                                  temperature=0.0,
                                  ignore_eos=False)
  260. @pytest.mark.parametrize(
  261. "common_llm_kwargs",
  262. [{
  263. # A "real" model (not tiny).
  264. "model_name": "meta-llama/Llama-2-7b-chat-hf",
  265. # Skip cuda graph recording for fast test.
  266. "enforce_eager": True,
  267. # Required for spec decode.
  268. "use_v2_block_manager": True,
  269. # Print spec metrics.
  270. "disable_log_stats": False,
  271. }])
  272. @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
  273. @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
  274. @pytest.mark.parametrize("test_llm_kwargs", [
  275. {
  276. "speculative_model": "JackFram/llama-68m",
  277. "num_speculative_tokens": 5,
  278. },
  279. ])
  280. @pytest.mark.parametrize("batch_size", [1])
  281. @pytest.mark.parametrize(
  282. "output_len",
  283. [
  284. # Use decently long output len for a high quality test.
  285. 256,
  286. ])
  287. @pytest.mark.parametrize("seed", [1])
  288. @fork_new_process_for_each_test
  289. def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
  290. aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
  291. baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
  292. seed: int):
  293. """Verify greedy equality on a "real" model and batch size of 1. This is
  294. separate from large BS tests to make identifying the source of bugs easier.
  295. """
  296. run_equality_correctness_test(aphrodite_runner,
  297. common_llm_kwargs,
  298. per_test_common_llm_kwargs,
  299. baseline_llm_kwargs,
  300. test_llm_kwargs,
  301. batch_size,
  302. max_output_len=output_len,
  303. seed=seed,
  304. temperature=0.0)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A "real" model (not tiny).
        "model_name": "meta-llama/Llama-2-7b-chat-hf",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test (the 7B target is expensive).
        64,
    ])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
        aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality with a "real" model on a nontrivial batch size.
    This is the closest test to a real production workload.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,

        # 2 for small prompt, 256//8 for generated.
        # Capping the GPU block budget at just what a single sequence needs
        # presumably forces the scheduler to preempt other sequences in the
        # batch mid-generation — confirm against the scheduler's behavior.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model_name": "JackFram/llama-160m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_with_preemption(
        aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
  399. @pytest.mark.parametrize(
  400. "common_llm_kwargs",
  401. [{
  402. "model_name": "JackFram/llama-160m",
  403. # Skip cuda graph recording for fast test.
  404. "enforce_eager": True,
  405. # Required for spec decode.
  406. "use_v2_block_manager": True
  407. }])
  408. @pytest.mark.parametrize(
  409. "per_test_common_llm_kwargs",
  410. [
  411. # As of this writing, Aphrodite only compiles with these 3 block sizes
  412. # by default.
  413. {
  414. "block_size": 8,
  415. },
  416. {
  417. "block_size": 16,
  418. },
  419. {
  420. "block_size": 32,
  421. },
  422. ])
  423. @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
  424. @pytest.mark.parametrize("test_llm_kwargs", [
  425. {
  426. "speculative_model": "JackFram/llama-68m",
  427. "num_speculative_tokens": 5,
  428. },
  429. ])
  430. @pytest.mark.parametrize("batch_size", [2])
  431. @pytest.mark.parametrize(
  432. "output_len",
  433. [
  434. # Use smaller output len for fast test.
  435. 32,
  436. ])
  437. @pytest.mark.parametrize("seed", [1])
  438. @fork_new_process_for_each_test
  439. def test_spec_decode_different_block_size(aphrodite_runner, common_llm_kwargs,
  440. per_test_common_llm_kwargs,
  441. baseline_llm_kwargs, test_llm_kwargs,
  442. batch_size: int, output_len: int,
  443. seed: int):
  444. """Verify greedy equality over different block sizes.
  445. """
  446. run_equality_correctness_test(aphrodite_runner,
  447. common_llm_kwargs,
  448. per_test_common_llm_kwargs,
  449. baseline_llm_kwargs,
  450. test_llm_kwargs,
  451. batch_size,
  452. max_output_len=output_len,
  453. seed=seed,
  454. temperature=0.0)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
            # Artificially limit the draft model max model len; this forces
            # aphrodite to skip speculation once the sequences grow beyond
            # 32-k tokens.
            "speculative_max_model_len": 32,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # This must be a good bit larger than speculative_max_model_len so that
        # we can test the case where all seqs are skipped, but still small to
        # ensure fast test.
        64,
    ])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_skip_speculation(aphrodite_runner, common_llm_kwargs,
                          per_test_common_llm_kwargs, baseline_llm_kwargs,
                          test_llm_kwargs, batch_size: int, output_len: int,
                          seed: int):
    """Verify greedy equality when some (or all) sequences skip speculation.

    We do this by setting the max model len of the draft model to an
    artificially low value, such that when the sequences grow beyond it, they
    are skipped in speculative decoding.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
  507. @pytest.mark.parametrize(
  508. "common_llm_kwargs",
  509. [{
  510. "model_name": "JackFram/llama-160m",
  511. # Skip cuda graph recording for fast test.
  512. "enforce_eager": True,
  513. # Required for spec decode.
  514. "use_v2_block_manager": True
  515. }])
  516. @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
  517. @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
  518. @pytest.mark.parametrize("test_llm_kwargs", [
  519. {
  520. "speculative_model": "JackFram/llama-68m",
  521. "num_speculative_tokens": 5,
  522. "speculative_disable_by_batch_size": 2,
  523. },
  524. ])
  525. @pytest.mark.parametrize("batch_size", [8])
  526. @pytest.mark.parametrize("output_len", [10])
  527. @pytest.mark.parametrize("seed", [1])
  528. @fork_new_process_for_each_test
  529. def test_disable_speculation(aphrodite_runner, common_llm_kwargs,
  530. per_test_common_llm_kwargs, baseline_llm_kwargs,
  531. test_llm_kwargs, batch_size: int, output_len: int,
  532. seed: int):
  533. """Verify greedy equality when all sequences disable speculation.
  534. """
  535. run_equality_correctness_test(aphrodite_runner,
  536. common_llm_kwargs,
  537. per_test_common_llm_kwargs,
  538. baseline_llm_kwargs,
  539. test_llm_kwargs,
  540. batch_size,
  541. max_output_len=output_len,
  542. seed=seed,
  543. temperature=0.0)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": k,
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_many_k(aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                output_len: int, seed: int):
    """Verify that speculative decoding produces exact equality to without spec
    decode with many different values of k.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": k,
            # Use the typical acceptance sampler instead of the default
            # rejection sampler for draft-token acceptance.
            "spec_decoding_acceptance_method": "typical_acceptance_sampler"
        }
        # Try a range of common k.
        for k in [1, 2, 3]
    ])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_typical_acceptance_sampling(aphrodite_runner, common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs, test_llm_kwargs,
                                     batch_size: int, output_len: int,
                                     seed: int):
    """Verify that speculative decoding produces exact equality to without spec
    decode with TypicalAcceptanceSampler as the draft token acceptance
    sampling method.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)