# test_stop_checker.py

from unittest.mock import MagicMock

import pytest
from transformers import PreTrainedTokenizer

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import Logprob, Sequence, SequenceStatus
from aphrodite.engine.output_processor.stop_checker import StopChecker


def sequence_with_eos(text: str, eos_token: str,
                      eos_token_id: int) -> Sequence:
    """
    Create a Sequence that ends with an EOS token.
    """
    seq = Sequence(
        seq_id=0,
        inputs={"prompt_token_ids": []},
        block_size=16,
        eos_token_id=eos_token_id,
    )
    seq.output_text = text + eos_token

    # Use dummy token ids above the EOS id so none of the output tokens
    # collide with it; only the final appended token is the EOS token.
    offset = eos_token_id + 1
    for i in range(offset, len(text) + offset):
        seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)})
    seq.append_token_id(token_id=eos_token_id,
                        logprobs={eos_token_id: Logprob(0.0)})

    seq.status = SequenceStatus.RUNNING
    return seq
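

# The quick check below is an illustrative sketch, not part of the original
# suite: it only exercises the helper above, and it assumes the Sequence
# class exposes `get_output_len()` and `get_last_token_id()` (as the
# vLLM-derived sequence class does).
def test_sequence_with_eos_helper_sketch():
    seq = sequence_with_eos(text="abc", eos_token="</s>", eos_token_id=2)

    # The helper appends one dummy token per character, plus the EOS id.
    assert seq.output_text == "abc</s>"
    assert seq.get_output_len() == len("abc") + 1
    assert seq.get_last_token_id() == 2
    assert seq.status == SequenceStatus.RUNNING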


@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
    ("This text ends with EOS token", "</s>", 2),
])
@pytest.mark.parametrize("ignore_eos", [True, False])
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.skip_global_cleanup
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
                           ignore_eos: bool, include_stop_str_in_output: bool):
  34. """
  35. Test the behavior of the StopChecker's maybe_stop_sequence method
  36. when an EOS token is encountered.
  37. This test covers:
  38. - When the EOS token should stop the sequence and be removed from the output
  39. - When the EOS token should stop the sequence and be included in the output
  40. - When the EOS token should be ignored, and the sequence continues
  41. """
    tokenizer = MagicMock(spec=PreTrainedTokenizer)
    get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
    stop_checker = StopChecker(max_model_len=1024,
                               get_tokenizer_for_seq=get_tokenizer_for_seq)

    seq = sequence_with_eos(
        text=text_wo_eos,
        eos_token=eos_token,
        eos_token_id=eos_token_id,
    )
    # The EOS token is the only text produced by the latest decode step.
    new_char_count = len(eos_token)

    # Note that `stop` and `stop_token_ids` are not specified
    sampling_params = SamplingParams(
        min_tokens=1,
        ignore_eos=ignore_eos,
        include_stop_str_in_output=include_stop_str_in_output)

    stop_checker.maybe_stop_sequence(
        seq=seq,
        new_char_count=new_char_count,
        sampling_params=sampling_params,
    )

    if ignore_eos:
        assert seq.status == SequenceStatus.RUNNING
        assert seq.output_text == text_wo_eos + eos_token
    elif include_stop_str_in_output:
        assert seq.status == SequenceStatus.FINISHED_STOPPED
        assert seq.output_text == text_wo_eos + eos_token
    else:
        assert seq.status == SequenceStatus.FINISHED_STOPPED
        assert seq.output_text == text_wo_eos
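

# A complementary sketch, not part of the original file: the same checker
# should also stop on an explicit stop string passed via `SamplingParams.stop`.
# The expectations that `seq.stop_reason` is set to the matched string and
# that the string is truncated from `output_text` unless
# `include_stop_str_in_output=True` mirror the EOS test above, but they are
# assumptions about the StopChecker/Sequence API rather than something this
# file verifies.
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.skip_global_cleanup
def test_stop_on_stop_string_sketch(include_stop_str_in_output: bool):
    tokenizer = MagicMock(spec=PreTrainedTokenizer)
    get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
    stop_checker = StopChecker(max_model_len=1024,
                               get_tokenizer_for_seq=get_tokenizer_for_seq)

    # Reuse the helper to build a sequence whose text ends with the stop
    # string; `ignore_eos=True` keeps the EOS path out of the way so only
    # the stop-string check can fire.
    stop_str = "</s>"
    seq = sequence_with_eos(text="Some text", eos_token=stop_str,
                            eos_token_id=2)

    sampling_params = SamplingParams(
        min_tokens=1,
        ignore_eos=True,
        stop=[stop_str],
        include_stop_str_in_output=include_stop_str_in_output)

    stop_checker.maybe_stop_sequence(
        seq=seq,
        new_char_count=len(stop_str),
        sampling_params=sampling_params,
    )

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert seq.stop_reason == stop_str
    if include_stop_str_in_output:
        assert seq.output_text == "Some text" + stop_str
    else:
        assert seq.output_text == "Some text"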