test_stop_reason.py

  1. """Test the different finish_reason="stop" situations during generation:
  2. 1. One of the provided stop strings
  3. 2. One of the provided stop tokens
  4. 3. The EOS token
  5. Run `pytest tests/engine/test_stop_reason.py`.
  6. """
import pytest
import transformers

from aphrodite import SamplingParams

MODEL = "facebook/opt-350m"
STOP_STR = "."
SEED = 42
MAX_TOKENS = 1024


@pytest.fixture
def aphrodite_model(aphrodite_runner):
    with aphrodite_runner(MODEL) as aphrodite_model:
        yield aphrodite_model


def test_stop_reason(aphrodite_model, example_prompts):
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
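    # Resolve the stop string to its token id so the same "." can be matched
    # below either as a stop token id or as a stop string.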
    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
    llm = aphrodite_model.model

    # test stop token
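    # With ignore_eos=True, the only early-exit path is the provided stop
    # token id; stop_reason should then report that token id.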
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               ignore_eos=True,
                               seed=SEED,
                               max_tokens=MAX_TOKENS,
                               stop_token_ids=[stop_token_id]))
    for output in outputs:
        output = output.outputs[0]
        assert output.finish_reason == "stop"
        assert output.stop_reason == stop_token_id

    # test stop string
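    # Same setup, but stopping on the stop string; stop_reason should carry
    # the matched string rather than a token id.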
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               ignore_eos=True,
                               seed=SEED,
                               max_tokens=MAX_TOKENS,
                               stop="."))
    for output in outputs:
        output = output.outputs[0]
        assert output.finish_reason == "stop"
        assert output.stop_reason == STOP_STR

    # test EOS token
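    # Default EOS handling: a prompt either stops on EOS (stop_reason stays
    # None) or hits the max_tokens budget and finishes with "length".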
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               seed=SEED, max_tokens=MAX_TOKENS))
    for output in outputs:
        output = output.outputs[0]
        assert output.finish_reason == "length" or (
            output.finish_reason == "stop" and output.stop_reason is None)