test_ignore_eos.py 1.0 KB

123456789101112131415161718192021222324252627282930313233
"""Make sure ignore_eos works.

Checks that generating with ``SamplingParams(ignore_eos=True)`` keeps decoding
until exactly ``max_tokens`` tokens have been produced for every example
prompt, i.e. an end-of-sequence token never stops generation early.

Run `pytest tests/samplers/test_ignore_eos.py`.
"""
  4. import pytest
  5. from aphrodite import SamplingParams
  6. # We also test with llama because it has generation_config to specify EOS
  7. # (past regression).
  8. MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"]
  9. @pytest.mark.parametrize("model", MODELS)
  10. @pytest.mark.parametrize("dtype", ["half"])
  11. @pytest.mark.parametrize("max_tokens", [512])
  12. def test_ignore_eos(
  13. aphrodite_runner,
  14. example_prompts,
  15. model: str,
  16. dtype: str,
  17. max_tokens: int,
  18. ) -> None:
  19. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  20. sampling_params = SamplingParams(max_tokens=max_tokens,
  21. ignore_eos=True)
  22. for prompt in example_prompts:
  23. ignore_eos_output = aphrodite_model.model.generate(
  24. prompt, sampling_params=sampling_params)
  25. output_length = len(ignore_eos_output[0].outputs[0].token_ids)
  26. assert output_length == max_tokens