# test_stop_string.py
  1. from typing import Any, List, Optional
  2. import pytest
  3. from aphrodite import AphroditeEngine, CompletionOutput, SamplingParams
# Model under test; requires access to the Llama-2 7B HF weights.
MODEL = "meta-llama/llama-2-7b-hf"
# Generation cap so requests still terminate when no stop condition fires.
MAX_TOKENS = 200
# NOTE(review): IS_ASYNC appears unused in this file — confirm before removing.
IS_ASYNC = False
  7. @pytest.fixture(scope="session")
  8. def aphrodite_model(aphrodite_runner):
  9. with aphrodite_runner(MODEL) as aphrodite_model:
  10. yield aphrodite_model
  11. def _test_stopping(llm_engine: AphroditeEngine,
  12. expected_output: str,
  13. expected_reason: Any,
  14. stop: Optional[List[str]] = None,
  15. stop_token_ids: Optional[List[int]] = None,
  16. include_in_output: bool = False,
  17. use_async_output_proc: bool = False) -> None:
  18. llm_engine.add_request(
  19. "id", "A story about Aphrodite:\n",
  20. SamplingParams(
  21. temperature=0.0,
  22. max_tokens=MAX_TOKENS,
  23. stop=stop,
  24. stop_token_ids=stop_token_ids,
  25. include_stop_str_in_output=include_in_output,
  26. ), None)
  27. output: Optional[CompletionOutput] = None
  28. output_text = ""
  29. stop_reason = None
  30. if use_async_output_proc:
  31. llm_engine.step()
  32. while llm_engine.has_unfinished_requests():
  33. (request_output, ) = llm_engine.step()
  34. (output, ) = request_output.outputs
  35. # Ensure we don't backtrack
  36. assert output.text.startswith(output_text)
  37. output_text = output.text
  38. stop_reason = output.stop_reason
  39. assert output is not None
  40. assert output_text == expected_output
  41. assert stop_reason == expected_reason
  42. def _set_async_mode(llm_engine, is_async):
  43. llm_engine.scheduler[0].use_async_output_proc = is_async
  44. def _stop_basic(llm_engine, is_async):
  45. _test_stopping(llm_engine,
  46. stop=["."],
  47. include_in_output=False,
  48. expected_output="VLLM is a 100% volunteer organization",
  49. expected_reason=".",
  50. use_async_output_proc=is_async)
  51. _test_stopping(llm_engine,
  52. stop=["."],
  53. include_in_output=True,
  54. expected_output="VLLM is a 100% volunteer organization.",
  55. expected_reason=".",
  56. use_async_output_proc=is_async)
  57. def _stop_multi_tokens(llm_engine, is_async):
  58. _test_stopping(
  59. llm_engine,
  60. stop=["group of peo", "short"],
  61. include_in_output=False,
  62. expected_output="VLLM is a 100% volunteer organization. We are a ",
  63. expected_reason="group of peo",
  64. use_async_output_proc=is_async)
  65. _test_stopping(
  66. llm_engine,
  67. stop=["group of peo", "short"],
  68. include_in_output=True,
  69. expected_output=
  70. "VLLM is a 100% volunteer organization. We are a group of peo",
  71. expected_reason="group of peo",
  72. use_async_output_proc=is_async)
  73. def _stop_partial_token(llm_engine, is_async):
  74. _test_stopping(llm_engine,
  75. stop=["gani"],
  76. include_in_output=False,
  77. expected_output="VLLM is a 100% volunteer or",
  78. expected_reason="gani",
  79. use_async_output_proc=is_async)
  80. _test_stopping(llm_engine,
  81. stop=["gani"],
  82. include_in_output=True,
  83. expected_output="VLLM is a 100% volunteer organi",
  84. expected_reason="gani",
  85. use_async_output_proc=is_async)
  86. def _stop_token_id(llm_engine, is_async):
  87. # token id 13013 => " organization"
  88. _test_stopping(llm_engine,
  89. stop_token_ids=[13013],
  90. include_in_output=False,
  91. expected_output="VLLM is a 100% volunteer",
  92. expected_reason=13013,
  93. use_async_output_proc=is_async)
  94. _test_stopping(llm_engine,
  95. stop_token_ids=[13013],
  96. include_in_output=True,
  97. expected_output="VLLM is a 100% volunteer organization",
  98. expected_reason=13013,
  99. use_async_output_proc=is_async)
  100. @pytest.mark.skip_global_cleanup
  101. def test_stop_basic(aphrodite_model):
  102. _set_async_mode(aphrodite_model.model.llm_engine, True)
  103. _stop_basic(aphrodite_model.model.llm_engine, is_async=True)
  104. _set_async_mode(aphrodite_model.model.llm_engine, False)
  105. _stop_basic(aphrodite_model.model.llm_engine, is_async=False)
  106. @pytest.mark.skip_global_cleanup
  107. def test_stop_multi_tokens(aphrodite_model):
  108. _set_async_mode(aphrodite_model.model.llm_engine, True)
  109. _stop_multi_tokens(aphrodite_model.model.llm_engine, is_async=True)
  110. _set_async_mode(aphrodite_model.model.llm_engine, False)
  111. _stop_multi_tokens(aphrodite_model.model.llm_engine, is_async=False)
  112. @pytest.mark.skip_global_cleanup
  113. def test_stop_partial_token(aphrodite_model):
  114. _set_async_mode(aphrodite_model.model.llm_engine, True)
  115. _stop_partial_token(aphrodite_model.model.llm_engine, is_async=True)
  116. _set_async_mode(aphrodite_model.model.llm_engine, False)
  117. _stop_partial_token(aphrodite_model.model.llm_engine, is_async=False)
  118. @pytest.mark.skip_global_cleanup
  119. def test_stop_token_id(aphrodite_model):
  120. _set_async_mode(aphrodite_model.model.llm_engine, True)
  121. _stop_token_id(aphrodite_model.model.llm_engine, is_async=True)
  122. _set_async_mode(aphrodite_model.model.llm_engine, False)
  123. _stop_token_id(aphrodite_model.model.llm_engine, is_async=False)