test_logprobs.py

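"""Tests for logprob support: prompt logprobs, sample logprobs, the
engine-level max_logprobs cap, and disabling logprobs entirely."""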
from typing import List

import pytest
import torch

from aphrodite import SamplingParams

from ..conftest import AphroditeRunner

MODELS = ["facebook/opt-125m"]

@pytest.mark.parametrize("model", MODELS)
# needed for comparing logprobs with HF
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [0, 6])
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
    hf_runner,
    aphrodite_runner,
    model,
    dtype,
    chunked_prefill_token_size: int,
    num_top_logprobs: int,
    detokenize: bool,
    example_prompts,
):
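    """Compare prompt and sample logprobs returned by Aphrodite against
    reference logprobs computed with the equivalent HuggingFace model.
    """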
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
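    # A chunked_prefill_token_size of -1 means chunked prefill is disabled.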
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    max_tokens = 5
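
    # Compute reference logprobs with the HuggingFace model first.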
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_logprobs = hf_model.generate_greedy_logprobs(
            example_prompts,
            max_tokens=max_tokens,
        )

    with aphrodite_runner(
            model,
            dtype=dtype,
            max_logprobs=num_top_logprobs,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as aphrodite_model:
        aphrodite_sampling_params = SamplingParams(max_tokens=max_tokens,
                                                   logprobs=num_top_logprobs,
                                                   prompt_logprobs=num_top_logprobs,
                                                   temperature=0.0,
                                                   detokenize=detokenize)
        aphrodite_results = aphrodite_model.model.generate(
            example_prompts, sampling_params=aphrodite_sampling_params)

    # Test whether logprobs are included in the results.
    for result in aphrodite_results:
        assert result.prompt_logprobs is not None
        assert result.outputs[0].logprobs is not None
        assert len(result.outputs[0].logprobs) == max_tokens
        for logprobs in result.outputs[0].logprobs:
            # If the output token is not included in the top X
            # logprobs, one extra entry may be returned.
            assert (len(logprobs) == num_top_logprobs
                    or len(logprobs) == num_top_logprobs + 1)

        output_text = result.outputs[0].text
        output_string_from_most_likely_tokens_lst: List[str] = []
        for top_logprobs in result.outputs[0].logprobs:
            top_logprob = next(iter(top_logprobs.values()))
            output_string_from_most_likely_tokens_lst.append(
                top_logprob.decoded_token)

        if detokenize:
            output_string_from_most_likely_tokens = "".join(
                output_string_from_most_likely_tokens_lst)
            assert output_text == output_string_from_most_likely_tokens, (
                "The output text from the top logprob for each token position "
                "should be the same as the output text in the result.")
        else:
            assert output_text == ''
            assert output_string_from_most_likely_tokens_lst == ([None] *
                                                                 max_tokens)

        # The first prompt logprob is always None.
        assert result.prompt_logprobs[0] is None
        for prompt_logprobs in result.prompt_logprobs[1:]:
            # If the prompt token is not included in the top X
            # logprobs, one extra entry may be returned.
            assert (len(prompt_logprobs) == num_top_logprobs
                    or len(prompt_logprobs) == num_top_logprobs + 1)

    # Test whether prompt logprobs are consistent with HF.
    for aphrodite_result, hf_logprob in zip(aphrodite_results, hf_logprobs):
        # Check prompt logprobs. The first prompt logprob is always None,
        # so we compare from index 1 onwards.
        aphrodite_prompt_logprobs = aphrodite_result.prompt_logprobs[1:]
        for i, aphrodite_prompt_logprob_dict in enumerate(
                aphrodite_prompt_logprobs):
            for token_id, logprob in aphrodite_prompt_logprob_dict.items():
                torch.testing.assert_close(logprob.logprob,
                                           hf_logprob[0][i][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)

        # Check sample logprobs.
        aphrodite_sample_logprobs = aphrodite_result.outputs[0].logprobs
        for i, top_logprobs in enumerate(aphrodite_sample_logprobs):
            for token_id, sample_logprob in top_logprobs.items():
                logprob = sample_logprob.logprob
                torch.testing.assert_close(logprob,
                                           hf_logprob[i][-1][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
                if detokenize:
                    assert isinstance(sample_logprob.decoded_token, str), (
                        "The token should be decoded by the time it is"
                        " returned to the user.")

    # Test whether prompt logprobs are correctly set.
    for aphrodite_result in aphrodite_results:
        token_ids = aphrodite_result.prompt_token_ids
        prompt_logprobs = aphrodite_result.prompt_logprobs

        # The first token doesn't have a logprob.
        assert prompt_logprobs[0] is None

        for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
            assert token_id in logprob_dict


def test_max_logprobs():
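    """Requests that ask for more logprobs than the engine-level
    max_logprobs cap should be rejected with a ValueError.
    """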
    runner = AphroditeRunner("facebook/opt-125m", max_logprobs=1)
    aphrodite_sampling_params = SamplingParams(logprobs=1)
    # should pass
    runner.generate(["Hello world"], sampling_params=aphrodite_sampling_params)

    # Requesting more logprobs than max_logprobs should be rejected.
    bad_sampling_params = SamplingParams(logprobs=2)
    with pytest.raises(ValueError):
        runner.generate(["Hello world"], sampling_params=bad_sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("detokenize", [True, False])
def test_none_logprobs(aphrodite_runner, model, chunked_prefill_token_size: int,
                       detokenize: bool, example_prompts):
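    """With logprobs=None, outputs should carry no logprobs and no
    cumulative logprob.
    """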
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size
    max_tokens = 5

    with aphrodite_runner(
            model,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as aphrodite_model:
        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
                                                       logprobs=None,
                                                       temperature=0.0,
                                                       detokenize=detokenize)
        results_logprobs_none = aphrodite_model.model.generate(
            example_prompts, sampling_params=sampling_params_logprobs_none)

    for i in range(len(results_logprobs_none)):
        assert results_logprobs_none[i].outputs[0].logprobs is None
        assert results_logprobs_none[i].outputs[0].cumulative_logprob is None