test_error_handling.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """Test that various errors are handled properly."""
  2. import asyncio
  3. import tempfile
  4. import time
  5. import uuid
  6. from unittest.mock import Mock
  7. import pytest
  8. from aphrodite import SamplingParams
  9. from aphrodite.common.utils import FlexibleArgumentParser
  10. from aphrodite.endpoints.openai.api_server import build_engine_client
  11. from aphrodite.endpoints.openai.args import make_arg_parser
  12. from aphrodite.engine.aphrodite_engine import AphroditeEngine
  13. from aphrodite.engine.args_tools import AsyncEngineArgs
  14. from aphrodite.engine.multiprocessing import MQEngineDeadError
  15. from aphrodite.engine.multiprocessing.engine import MQAphroditeEngine
  16. from aphrodite.lora.request import LoRARequest
  17. from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine
  18. MODEL = "google/gemma-1.1-2b-it"
  19. ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
  20. RAISED_ERROR = KeyError
  21. RAISED_VALUE = "foo"
  22. @pytest.fixture(scope="function")
  23. def tmp_socket():
  24. with tempfile.TemporaryDirectory() as td:
  25. yield f"ipc://{td}/{uuid.uuid4()}"
  26. def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
  27. # Make engine.
  28. engine = MQAphroditeEngine.from_engine_args(
  29. engine_args=engine_args,
  30. ipc_path=ipc_path)
  31. # Raise error during first forward pass.
  32. engine.engine.model_executor.execute_model = Mock(
  33. side_effect=RAISED_ERROR(RAISED_VALUE))
  34. # Run engine.
  35. engine.start()
  36. @pytest.mark.asyncio
  37. async def test_evil_forward(tmp_socket):
  38. with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
  39. ipc_path=tmp_socket,
  40. run_fn=run_with_evil_forward) as engine:
  41. client = await engine.make_client()
  42. # Server should be healthy after initial probe.
  43. await asyncio.sleep(2.0)
  44. await client.check_health()
  45. # Throws an error in first forward pass.
  46. with pytest.raises(RAISED_ERROR):
  47. async for _ in client.generate(prompt="Hello my name is",
  48. sampling_params=SamplingParams(),
  49. request_id=uuid.uuid4()):
  50. pass
  51. assert client.errored
  52. # Engine is errored, should get ENGINE_DEAD_ERROR.
  53. with pytest.raises(MQEngineDeadError):
  54. async for _ in client.generate(prompt="Hello my name is",
  55. sampling_params=SamplingParams(),
  56. request_id=uuid.uuid4()):
  57. pass
  58. assert client.errored
  59. await asyncio.sleep(1.0)
  60. with pytest.raises(RAISED_ERROR):
  61. await client.check_health()
  62. assert client.errored
  63. # Shutdown.
  64. client.close()
  65. def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
  66. ipc_path: str):
  67. # Make engine.
  68. engine = MQAphroditeEngine.from_engine_args(
  69. engine_args=engine_args,
  70. ipc_path=ipc_path)
  71. # Raise error during first forward pass.
  72. engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
  73. # Run engine.
  74. engine.start()
  75. @pytest.mark.asyncio
  76. async def test_failed_health_check(tmp_socket):
  77. with RemoteMQAphroditeEngine(
  78. engine_args=ENGINE_ARGS,
  79. ipc_path=tmp_socket,
  80. run_fn=run_with_evil_model_executor_health) as engine:
  81. client = await engine.make_client()
  82. assert client.is_running
  83. # Health probe should throw RAISED_ERROR.
  84. await asyncio.sleep(15.)
  85. with pytest.raises(RAISED_ERROR):
  86. await client.check_health()
  87. assert client.errored
  88. # Generate call should throw ENGINE_DEAD_ERROR
  89. with pytest.raises(MQEngineDeadError):
  90. async for _ in client.generate(prompt="Hello my name is",
  91. sampling_params=SamplingParams(),
  92. request_id=uuid.uuid4()):
  93. pass
  94. client.close()
  95. def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
  96. # Make engine.
  97. engine = MQAphroditeEngine.from_engine_args(
  98. engine_args=engine_args,
  99. ipc_path=ipc_path)
  100. # Raise error during abort call.
  101. engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
  102. # Run engine.
  103. engine.start()
  104. @pytest.mark.asyncio
  105. async def test_failed_abort(tmp_socket):
  106. with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
  107. ipc_path=tmp_socket,
  108. run_fn=run_with_evil_abort) as engine:
  109. client = await engine.make_client()
  110. assert client.is_running
  111. # Firsh check health should work.
  112. await client.check_health()
  113. # Trigger an abort on the client side.
  114. async def bad_abort_after_2s():
  115. await asyncio.sleep(2.0)
  116. await client.abort(request_id="foo")
  117. # Trigger an abort in 2s from now.
  118. abort_task = asyncio.create_task(bad_abort_after_2s())
  119. # Exception in abort() will happen during this generation.
  120. # This will kill the engine and should return ENGINE_DEAD_ERROR
  121. # with reference to the original KeyError("foo")
  122. with pytest.raises(MQEngineDeadError) as execinfo:
  123. async for _ in client.generate(
  124. prompt="Hello my name is",
  125. sampling_params=SamplingParams(max_tokens=2000),
  126. request_id=uuid.uuid4()):
  127. pass
  128. assert "KeyError" in repr(execinfo.value)
  129. assert client.errored
  130. await abort_task
  131. # This should raise the original error.
  132. with pytest.raises(RAISED_ERROR):
  133. await client.check_health()
  134. client.close()
  135. @pytest.mark.asyncio
  136. async def test_bad_request(tmp_socket):
  137. with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
  138. ipc_path=tmp_socket) as engine:
  139. client = await engine.make_client()
  140. # Invalid request should fail, but not crash the server.
  141. with pytest.raises(ValueError):
  142. async for _ in client.generate(prompt="Hello my name is",
  143. sampling_params=SamplingParams(),
  144. request_id="abcd-1",
  145. lora_request=LoRARequest(
  146. "invalid-lora", 1,
  147. "invalid-path")):
  148. pass
  149. # This request should be okay.
  150. async for _ in client.generate(prompt="Hello my name is",
  151. sampling_params=SamplingParams(),
  152. request_id="abcd-2"):
  153. pass
  154. # Shutdown.
  155. client.close()
  156. @pytest.mark.asyncio
  157. async def test_mp_crash_detection(monkeypatch):
  158. parser = FlexibleArgumentParser(
  159. description="Aphrodite's remote OpenAI server.")
  160. parser = make_arg_parser(parser)
  161. args = parser.parse_args([])
  162. # When AphroditeEngine is loaded, it will crash.
  163. def mock_init():
  164. raise ValueError
  165. monkeypatch.setattr(AphroditeEngine, "__init__", mock_init)
  166. start = time.perf_counter()
  167. async with build_engine_client(args):
  168. pass
  169. end = time.perf_counter()
  170. assert end - start < 60, (
  171. "Expected Aphrodite to gracefully shutdown in <60s "
  172. "if there is an error in the startup.")
  173. @pytest.mark.asyncio
  174. async def test_mp_cuda_init():
  175. # it should not crash, when cuda is initialized
  176. # in the API server process
  177. import torch
  178. torch.cuda.init()
  179. parser = FlexibleArgumentParser(
  180. description="Aphrodite's remote OpenAI server.")
  181. parser = make_arg_parser(parser)
  182. args = parser.parse_args([])
  183. async with build_engine_client(args):
  184. pass