# test_ranks.py
  1. import pytest
  2. from aphrodite import SamplingParams
# Model names that test_ranks is parametrized over.
MODELS = ["facebook/opt-125m"]
  4. @pytest.mark.parametrize("model", MODELS)
  5. @pytest.mark.parametrize("dtype", ["half"])
  6. def test_ranks(
  7. aphrodite_runner,
  8. model,
  9. dtype,
  10. example_prompts,
  11. ):
  12. max_tokens = 5
  13. num_top_logprobs = 5
  14. num_prompt_logprobs = 5
  15. with aphrodite_runner(model, dtype=dtype,
  16. max_logprobs=num_top_logprobs) as aphrodite_model:
  17. ## Test greedy logprobs ranks
  18. aphrodite_sampling_params = SamplingParams(
  19. temperature=0.0,
  20. top_p=1.0,
  21. max_tokens=max_tokens,
  22. logprobs=num_top_logprobs,
  23. prompt_logprobs=num_prompt_logprobs)
  24. aphrodite_results = aphrodite_model.generate_w_logprobs(example_prompts,
  25. aphrodite_sampling_params)
  26. ## Test non-greedy logprobs ranks
  27. sampling_params = SamplingParams(temperature=1.0,
  28. top_p=1.0,
  29. max_tokens=max_tokens,
  30. logprobs=num_top_logprobs,
  31. prompt_logprobs=num_prompt_logprobs)
  32. res = aphrodite_model.generate_w_logprobs(example_prompts,
  33. sampling_params)
  34. for result in aphrodite_results:
  35. assert result[2] is not None
  36. assert len(result[2]) == len(result[0])
  37. # check whether all chosen tokens have ranks = 1
  38. for token, logprobs in zip(result[0], result[2]):
  39. assert token in logprobs
  40. assert logprobs[token].rank == 1
  41. for result in res:
  42. assert result[2] is not None
  43. assert len(result[2]) == len(result[0])
  44. # check whether all chosen tokens have ranks
  45. for token, logprobs in zip(result[0], result[2]):
  46. assert logprobs[token].rank >= 1