test_stop_reason.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. """Test the different finish_reason="stop" situations during generation:
  2. 1. One of the provided stop strings
  3. 2. One of the provided stop tokens
  4. 3. The EOS token
  5. Run `pytest tests/engine/test_stop_reason.py`.
  6. """
  7. import pytest
  8. import transformers
  9. from aphrodite import SamplingParams
# Small OPT checkpoint keeps the test fast while still exercising stop handling.
MODEL = "facebook/opt-350m"
# Stop string that also maps to a single tokenizer token (see
# convert_tokens_to_ids below), so the same character drives both the
# stop-string and stop-token-id test paths.
STOP_STR = "."
# Fixed sampling seed for deterministic generations across runs.
SEED = 42
# Generous budget so generations can hit a stop condition before truncation.
MAX_TOKENS = 1024
  14. @pytest.fixture
  15. def aphrodite_model(aphrodite_runner):
  16. aphrodite_model = aphrodite_runner(MODEL)
  17. yield aphrodite_model
  18. del aphrodite_model
  19. def test_stop_reason(aphrodite_model, example_prompts):
  20. tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
  21. stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
  22. llm = aphrodite_model.model
  23. # test stop token
  24. outputs = llm.generate(example_prompts,
  25. sampling_params=SamplingParams(
  26. seed=SEED,
  27. max_tokens=MAX_TOKENS,
  28. stop_token_ids=[stop_token_id]))
  29. for output in outputs:
  30. output = output.outputs[0]
  31. assert output.finish_reason == "stop"
  32. assert output.stop_reason == stop_token_id
  33. # test stop string
  34. outputs = llm.generate(example_prompts,
  35. sampling_params=SamplingParams(
  36. seed=SEED, max_tokens=MAX_TOKENS, stop="."))
  37. for output in outputs:
  38. output = output.outputs[0]
  39. assert output.finish_reason == "stop"
  40. assert output.stop_reason == STOP_STR
  41. # test EOS token
  42. outputs = llm.generate(example_prompts,
  43. sampling_params=SamplingParams(
  44. seed=SEED, max_tokens=MAX_TOKENS))
  45. for output in outputs:
  46. output = output.outputs[0]
  47. assert output.finish_reason == "length" or (
  48. output.finish_reason == "stop" and output.stop_reason is None)