123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- """Test that various errors are handled properly."""
- import asyncio
- import tempfile
- import time
- import uuid
- from unittest.mock import Mock
- import pytest
- from aphrodite import SamplingParams
- from aphrodite.common.utils import FlexibleArgumentParser
- from aphrodite.endpoints.openai.api_server import build_engine_client
- from aphrodite.endpoints.openai.args import make_arg_parser
- from aphrodite.engine.aphrodite_engine import AphroditeEngine
- from aphrodite.engine.args_tools import AsyncEngineArgs
- from aphrodite.engine.multiprocessing import MQEngineDeadError
- from aphrodite.engine.multiprocessing.engine import MQAphroditeEngine
- from aphrodite.lora.request import LoRARequest
- from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine
# Model loaded by every engine spun up in this module.
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)

# Exception type injected into the engine by the failure-injection helpers
# in this module, plus a payload value for it.
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique ``ipc://`` endpoint located in a throwaway directory.

    The temporary directory is removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        socket_name = uuid.uuid4()
        yield f"ipc://{scratch_dir}/{socket_name}"
def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
    """Start an MQ engine whose model forward pass always raises.

    ``execute_model`` on the engine's model executor is replaced by a
    ``Mock`` that raises ``RAISED_ERROR(RAISED_VALUE)`` whenever it is
    called. Intended to be passed as ``run_fn`` to
    ``RemoteMQAphroditeEngine``.
    """
    mq_engine = MQAphroditeEngine.from_engine_args(engine_args=engine_args,
                                                   ipc_path=ipc_path)

    injected_error = RAISED_ERROR(RAISED_VALUE)
    executor = mq_engine.engine.model_executor
    executor.execute_model = Mock(side_effect=injected_error)

    mq_engine.start()
@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
    """An exception in the forward pass must kill the engine.

    The first request surfaces the original error. After that the client
    reports itself as errored, new requests fail with ``MQEngineDeadError``,
    and ``check_health`` re-raises the original error.
    """
    with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
                                 ipc_path=tmp_socket,
                                 run_fn=run_with_evil_forward) as engine:
        client = await engine.make_client()
        try:
            # Server should be healthy after initial probe.
            await asyncio.sleep(2.0)
            await client.check_health()

            # Throws an error in first forward pass.
            with pytest.raises(RAISED_ERROR):
                async for _ in client.generate(
                        prompt="Hello my name is",
                        sampling_params=SamplingParams(),
                        request_id=uuid.uuid4()):
                    pass
            assert client.errored

            # Engine is errored, should get ENGINE_DEAD_ERROR.
            with pytest.raises(MQEngineDeadError):
                async for _ in client.generate(
                        prompt="Hello my name is",
                        sampling_params=SamplingParams(),
                        request_id=uuid.uuid4()):
                    pass
            assert client.errored

            await asyncio.sleep(1.0)
            with pytest.raises(RAISED_ERROR):
                await client.check_health()
            assert client.errored
        finally:
            # Shutdown. This runs in a finally so the client is closed even
            # when one of the assertions above fails.
            client.close()
def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
                                        ipc_path: str):
    """Start an MQ engine whose model-executor health check always fails.

    ``check_health`` on the engine's model executor is replaced by a
    ``Mock`` that raises ``RAISED_ERROR`` whenever it is called. Forward
    passes are left untouched. Intended to be passed as ``run_fn`` to
    ``RemoteMQAphroditeEngine``.
    """
    # Make engine.
    engine = MQAphroditeEngine.from_engine_args(
        engine_args=engine_args,
        ipc_path=ipc_path)

    # Raise error every time the model executor's health is checked.
    engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)

    # Run engine.
    engine.start()
@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
    """A failing executor health check must mark the engine as dead.

    After the injected health-check failure, ``check_health`` re-raises
    ``RAISED_ERROR``, the client reports itself as errored, and new
    requests fail with ``MQEngineDeadError``.
    """
    with RemoteMQAphroditeEngine(
            engine_args=ENGINE_ARGS,
            ipc_path=tmp_socket,
            run_fn=run_with_evil_model_executor_health) as engine:
        client = await engine.make_client()
        try:
            assert client.is_running

            # Health probe should throw RAISED_ERROR. The 15 s sleep assumes
            # the engine's health probe runs within that window (presumably
            # periodic; confirm against the engine's probe interval).
            await asyncio.sleep(15.)
            with pytest.raises(RAISED_ERROR):
                await client.check_health()
            assert client.errored

            # Generate call should throw ENGINE_DEAD_ERROR.
            with pytest.raises(MQEngineDeadError):
                async for _ in client.generate(
                        prompt="Hello my name is",
                        sampling_params=SamplingParams(),
                        request_id=uuid.uuid4()):
                    pass
        finally:
            # Close the client even when an assertion above fails.
            client.close()
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
    """Start an MQ engine whose ``abort_request`` always raises.

    The injected exception is ``RAISED_ERROR(RAISED_VALUE)``, the same form
    ``run_with_evil_forward`` uses. The error therefore carries its payload
    instead of being a bare, message-less exception. Intended to be passed
    as ``run_fn`` to ``RemoteMQAphroditeEngine``.
    """
    # Make engine.
    engine = MQAphroditeEngine.from_engine_args(
        engine_args=engine_args,
        ipc_path=ipc_path)

    # Raise error during abort call.
    engine.engine.abort_request = Mock(
        side_effect=RAISED_ERROR(RAISED_VALUE))

    # Run engine.
    engine.start()
@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
    """An exception raised while aborting a request must kill the engine.

    An abort is fired 2 s into a long generation. The generation must then
    fail with ``MQEngineDeadError`` referencing the original ``KeyError``,
    and ``check_health`` must re-raise the original error.
    """
    with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
                                 ipc_path=tmp_socket,
                                 run_fn=run_with_evil_abort) as engine:
        client = await engine.make_client()
        abort_task = None
        try:
            assert client.is_running

            # First check health should work.
            await client.check_health()

            # Trigger an abort on the client side.
            async def bad_abort_after_2s():
                await asyncio.sleep(2.0)
                await client.abort(request_id="foo")

            # Trigger an abort in 2s from now.
            abort_task = asyncio.create_task(bad_abort_after_2s())

            # Exception in abort() will happen during this generation.
            # This will kill the engine and should return ENGINE_DEAD_ERROR
            # with reference to the original KeyError.
            with pytest.raises(MQEngineDeadError) as execinfo:
                async for _ in client.generate(
                        prompt="Hello my name is",
                        sampling_params=SamplingParams(max_tokens=2000),
                        request_id=uuid.uuid4()):
                    pass
            assert "KeyError" in repr(execinfo.value)
            assert client.errored

            await abort_task

            # This should raise the original error.
            with pytest.raises(RAISED_ERROR):
                await client.check_health()
        finally:
            # If an assertion failed before the abort task was awaited,
            # cancel it so it does not outlive the test.
            if abort_task is not None and not abort_task.done():
                abort_task.cancel()
            client.close()
@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
    """An invalid request must fail without taking the server down.

    A request with an invalid LoRA must raise ``ValueError``. A following
    well-formed request on the same client must then complete normally.
    """
    with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
                                 ipc_path=tmp_socket) as engine:
        client = await engine.make_client()
        try:
            # Invalid request should fail, but not crash the server.
            with pytest.raises(ValueError):
                async for _ in client.generate(
                        prompt="Hello my name is",
                        sampling_params=SamplingParams(),
                        request_id="abcd-1",
                        lora_request=LoRARequest("invalid-lora", 1,
                                                 "invalid-path")):
                    pass

            # This request should be okay.
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id="abcd-2"):
                pass
        finally:
            # Shutdown. This also runs when an assertion above fails.
            client.close()
@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch):
    """Aphrodite must shut down quickly when the engine fails to start.

    ``AphroditeEngine.__init__`` is patched to raise ``ValueError``. The
    test asserts that entering and leaving ``build_engine_client`` takes
    less than 60 seconds.
    """
    parser = FlexibleArgumentParser(
        description="Aphrodite's remote OpenAI server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args([])

    # When AphroditeEngine is loaded, it will crash. The replacement
    # ``__init__`` is called with ``self`` and the constructor arguments.
    # A zero-argument signature would therefore fail with TypeError instead
    # of the intended ValueError, so accept and ignore everything.
    # NOTE(review): the patch only affects processes that inherit this
    # interpreter's state; confirm how the engine process is started.
    def mock_init(*_args, **_kwargs):
        raise ValueError

    monkeypatch.setattr(AphroditeEngine, "__init__", mock_init)

    start = time.perf_counter()
    async with build_engine_client(args):
        pass
    end = time.perf_counter()

    assert end - start < 60, (
        "Expected Aphrodite to gracefully shutdown in <60s "
        "if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
    """Building the engine client must not crash when CUDA has already been
    initialized in the API server process."""
    import torch
    torch.cuda.init()

    base_parser = FlexibleArgumentParser(
        description="Aphrodite's remote OpenAI server.")
    cli_args = make_arg_parser(base_parser).parse_args([])

    async with build_engine_client(cli_args):
        pass
|