# test_logprobs.py

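"""End-to-end tests that verify logprobs produced with speculative decoding
match those produced by the non-speculative baseline.

The parametrized ``*_llm_kwargs`` dictionaries below are consumed by the
``baseline_llm_generator`` and ``test_llm_generator`` fixtures, which are
presumably defined alongside ``get_logprobs_from_llm_generator`` in the local
``conftest``.
"""
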
import math
from itertools import cycle

import pytest

from aphrodite import SamplingParams

from .conftest import get_logprobs_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify output logprobs are equal with and without speculative decoding.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int,
                           num_logprobs: int):
    """Verify output logprobs are equal with and without spec decode.

    This specifies a number of logprobs >1.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True,
                                         logprob_rank=num_logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }, {
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 6,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify logprob greedy equality with different speculation lengths.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,

        # Artificially limit the draft model max model len; this forces
        # aphrodite to skip speculation once the sequences grow beyond
        # 32-k tokens.
        "speculative_max_model_len": 32,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_when_skip_speculation(baseline_llm_generator,
                                        test_llm_generator, batch_size: int,
                                        output_len: int):
    """Verify logprobs greedy equality when some sequences skip speculation.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Verify at least one logprob result has num_logprobs+1 entries, which
    tests the case where the sampled token is not in the top-k logprobs.

    Ideally, this test should also validate equality with the non-spec
    baseline logprobs; this is left as a future improvement.
    """
    batch_size = 8
    max_output_len = output_len
    force_output_len = True
    logprob_rank = 5
    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is known for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set
    # the sampling params to ignore the eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
        logprobs=logprob_rank,
    )
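
    # With temperature=1.0 the sampled token can fall outside the requested
    # top-k; the engine then also returns the sampled token's logprob, so
    # such positions carry logprob_rank + 1 entries (this is the behavior
    # asserted below).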
    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    num_returned_logprobs = [
        len(logprob_dict) for seq_logprobs in spec_batch_logprobs
        for logprob_dict in seq_logprobs
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any([
        num_returned > logprob_rank for num_returned in num_returned_logprobs
    ])


def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         logprob_rank: int = 1):
    """Helper method that compares the logprobs outputs of both the baseline
    LLM and the test LLM. It asserts greedy equality of the logprobs when the
    temperature is zero.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is known for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set
    # the sampling params to ignore the eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_logprobs = get_logprobs_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)
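
    # Each element of *_batch_logprobs is a per-sequence list with one entry
    # per generated position; each entry maps token_id to a Logprob object
    # exposing .rank and .logprob (inferred from how the values are used
    # below).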
    assert len(baseline_batch_logprobs) == len(prompts)
    assert len(spec_batch_logprobs) == len(prompts)

    # For each sequence in the batch.
    for i, (baseline_logprobs, spec_logprobs) in enumerate(
            zip(baseline_batch_logprobs, spec_batch_logprobs)):
        assert len(spec_logprobs) == len(baseline_logprobs)

        # For each generated position of the sequence.
        for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
                zip(spec_logprobs, baseline_logprobs)):

            # Map rank to token/logprob in spec output.
            spec_rank_to_token_id = {
                value.rank: key
                for key, value in spec_pos_logprobs.items()
            }
            spec_rank_to_logprob = {
                value.rank: value.logprob
                for key, value in spec_pos_logprobs.items()
            }

            # Map rank to token/logprob in baseline output.
            baseline_rank_to_token_id = {
                value.rank: key
                for key, value in baseline_pos_logprobs.items()
            }
            baseline_rank_to_logprob = {
                value.rank: value.logprob
                for key, value in baseline_pos_logprobs.items()
            }

            # Assert set of ranks returned is equal.
            assert set(spec_rank_to_token_id.keys()) == set(
                baseline_rank_to_token_id.keys())

            # Assert each logprob/token id is correct, keyed by rank.
            for rank in sorted(set(spec_rank_to_token_id.keys())):
                assert spec_rank_to_token_id[
                    rank] == baseline_rank_to_token_id[rank], f"{rank}"
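
                # The loose 1e-1 absolute tolerance presumably allows for
                # numerical differences between the spec-decode and baseline
                # forward passes, rather than requiring bitwise equality.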
                assert math.isclose(
                    a=spec_rank_to_logprob[rank],
                    b=baseline_rank_to_logprob[rank],
                    abs_tol=1e-1,
                )