# test_stop_string.py: end-to-end checks that generation halts on stop strings
# and stop token ids, both with and without the stop text kept in the output.
  1. from typing import Any, List, Optional
  2. import pytest
  3. from aphrodite import AphroditeEngine, CompletionOutput, SamplingParams
  4. MODEL = "meta-llama/llama-2-7b-hf"
  5. MAX_TOKENS = 200
  6. @pytest.fixture(scope="session")
  7. def aphrodite_model(aphrodite_runner):
  8. with aphrodite_runner(MODEL) as aphrodite_model:
  9. yield aphrodite_model
  10. @pytest.mark.skip_global_cleanup
  11. def test_stop_basic(aphrodite_model):
  12. _test_stopping(aphrodite_model.model.llm_engine,
  13. stop=["."],
  14. include_in_output=False,
  15. expected_output="VLLM is a 100% volunteer organization",
  16. expected_reason=".")
  17. _test_stopping(aphrodite_model.model.llm_engine,
  18. stop=["."],
  19. include_in_output=True,
  20. expected_output="VLLM is a 100% volunteer organization.",
  21. expected_reason=".")
  22. @pytest.mark.skip_global_cleanup
  23. def test_stop_multi_tokens(aphrodite_model):
  24. _test_stopping(
  25. aphrodite_model.model.llm_engine,
  26. stop=["group of peo", "short"],
  27. include_in_output=False,
  28. expected_output="VLLM is a 100% volunteer organization. We are a ",
  29. expected_reason="group of peo")
  30. _test_stopping(
  31. aphrodite_model.model.llm_engine,
  32. stop=["group of peo", "short"],
  33. include_in_output=True,
  34. expected_output=
  35. "VLLM is a 100% volunteer organization. We are a group of peo",
  36. expected_reason="group of peo")
  37. @pytest.mark.skip_global_cleanup
  38. def test_stop_partial_token(aphrodite_model):
  39. _test_stopping(aphrodite_model.model.llm_engine,
  40. stop=["gani"],
  41. include_in_output=False,
  42. expected_output="VLLM is a 100% volunteer or",
  43. expected_reason="gani")
  44. _test_stopping(aphrodite_model.model.llm_engine,
  45. stop=["gani"],
  46. include_in_output=True,
  47. expected_output="VLLM is a 100% volunteer organi",
  48. expected_reason="gani")
  49. @pytest.mark.skip_global_cleanup
  50. def test_stop_token_id(aphrodite_model):
  51. # token id 13013 => " organization"
  52. _test_stopping(aphrodite_model.model.llm_engine,
  53. stop_token_ids=[13013],
  54. include_in_output=False,
  55. expected_output="VLLM is a 100% volunteer",
  56. expected_reason=13013)
  57. _test_stopping(aphrodite_model.model.llm_engine,
  58. stop_token_ids=[13013],
  59. include_in_output=True,
  60. expected_output="VLLM is a 100% volunteer organization",
  61. expected_reason=13013)
  62. def _test_stopping(llm_engine: AphroditeEngine,
  63. expected_output: str,
  64. expected_reason: Any,
  65. stop: Optional[List[str]] = None,
  66. stop_token_ids: Optional[List[int]] = None,
  67. include_in_output: bool = False) -> None:
  68. llm_engine.add_request(
  69. "id", "A story about vLLM:\n",
  70. SamplingParams(
  71. temperature=0.0,
  72. max_tokens=MAX_TOKENS,
  73. stop=stop,
  74. stop_token_ids=stop_token_ids,
  75. include_stop_str_in_output=include_in_output,
  76. ), None)
  77. output: Optional[CompletionOutput] = None
  78. output_text = ""
  79. stop_reason = None
  80. while llm_engine.has_unfinished_requests():
  81. (request_output, ) = llm_engine.step()
  82. (output, ) = request_output.outputs
  83. # Ensure we don't backtrack
  84. assert output.text.startswith(output_text)
  85. output_text = output.text
  86. stop_reason = output.stop_reason
  87. assert output is not None
  88. assert output_text == expected_output
  89. assert stop_reason == expected_reason