# test_inputs.py (1.3 KB)
  1. from typing import List
  2. import pytest
  3. from aphrodite.inputs.parse import parse_and_batch_prompt
  4. STRING_INPUTS = [
  5. '',
  6. 'foo',
  7. 'foo bar',
  8. 'foo baz bar',
  9. 'foo bar qux baz',
  10. ]
  11. TOKEN_INPUTS = [
  12. [-1],
  13. [1],
  14. [1, 2],
  15. [1, 3, 4],
  16. [1, 2, 4, 3],
  17. ]
  18. INPUTS_SLICES = [
  19. slice(None, None, -1),
  20. slice(None, None, 2),
  21. slice(None, None, -2),
  22. ]
  23. def test_parse_single_batch_empty():
  24. with pytest.raises(ValueError, match="at least one prompt"):
  25. parse_and_batch_prompt([])
  26. with pytest.raises(ValueError, match="at least one prompt"):
  27. parse_and_batch_prompt([[]])
  28. @pytest.mark.parametrize('string_input', STRING_INPUTS)
  29. def test_parse_single_batch_string_consistent(string_input: str):
  30. assert parse_and_batch_prompt(string_input) \
  31. == parse_and_batch_prompt([string_input])
  32. @pytest.mark.parametrize('token_input', TOKEN_INPUTS)
  33. def test_parse_single_batch_token_consistent(token_input: List[int]):
  34. assert parse_and_batch_prompt(token_input) \
  35. == parse_and_batch_prompt([token_input])
  36. @pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
  37. def test_parse_single_batch_string_slice(inputs_slice: slice):
  38. assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
  39. == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])