test_logprobs.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import pytest
  2. import torch
  3. from aphrodite import SamplingParams
  4. MODELS = ["EleutherAI/pythia-70m-deduped"]
  5. @pytest.mark.parametrize("model", MODELS)
  6. @pytest.mark.parametrize("dtype", ["half"])
  7. def test_get_prompt_logprobs(
  8. hf_runner,
  9. aphrodite_runner,
  10. model,
  11. dtype,
  12. example_prompts,
  13. ):
  14. max_tokens = 5
  15. hf_model = hf_runner(model, dtype=dtype)
  16. hf_logprobs = hf_model.generate_greedy_logprobs(
  17. example_prompts,
  18. max_tokens=max_tokens,
  19. )
  20. del hf_model
  21. aphrodite_model = aphrodite_runner(model, dtype=dtype)
  22. aphrodite_sampling_params = SamplingParams(max_tokens=max_tokens,
  23. logprobs=5,
  24. prompt_logprobs=5,
  25. temperature=0.0)
  26. aphrodite_results = aphrodite_model.model.generate(
  27. example_prompts, sampling_params=aphrodite_sampling_params)
  28. # Test whether logprobs are included in the results.
  29. for result in aphrodite_results:
  30. assert result.prompt_logprobs is not None
  31. assert result.outputs[0].logprobs is not None
  32. # Test whether prompt logprobs are consistent with HF
  33. for aphrodite_result, hf_logprob in zip(aphrodite_results, hf_logprobs):
  34. # Check prompt logprobs
  35. aphrodite_prompt_logprobs = aphrodite_result.prompt_logprobs[1:]
  36. for i, aphrodite_prompt_logprob_dict in enumerate(
  37. aphrodite_prompt_logprobs):
  38. for token_id, logprob in aphrodite_prompt_logprob_dict.items():
  39. torch.testing.assert_close(logprob,
  40. hf_logprob[0][i][token_id].item(),
  41. atol=1e-2,
  42. rtol=1e-2)
  43. aphrodite_sample_logprobs = aphrodite_result.outputs[0].logprobs
  44. for i, aphrodite_sample_logprob_dict in enumerate(
  45. aphrodite_sample_logprobs):
  46. for token_id, logprob in aphrodite_sample_logprob_dict.items():
  47. torch.testing.assert_close(logprob,
  48. hf_logprob[i][-1][token_id].item(),
  49. atol=1e-2,
  50. rtol=1e-2)