123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- import asyncio
- import tempfile
- import unittest
- import unittest.mock
- import uuid
- import pytest
- import pytest_asyncio
- from aphrodite.endpoints.openai.rpc.client import (AsyncEngineRPCClient,
- RPCClientClosedError)
- from aphrodite.endpoints.openai.rpc.server import AsyncEngineRPCServer
- from aphrodite.engine.async_aphrodite import AsyncAphrodite
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique ``ipc://`` endpoint located in a throwaway directory.

    The temporary directory, and anything created inside it, is removed
    when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        socket_name = uuid.uuid4()
        yield f"ipc://{scratch_dir}/{socket_name}"


@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
    """Run an ``AsyncEngineRPCServer`` on ``tmp_socket`` with a mocked engine.

    ``AsyncAphrodite.from_engine_args`` is patched only while the server is
    being constructed, so any engine it builds that way is an ``AsyncMock``.
    The server loop runs as a background task on the test's event loop for
    the lifetime of the test.

    Yields:
        The running ``AsyncEngineRPCServer``.
    """
    dummy_engine = unittest.mock.AsyncMock()

    def dummy_engine_builder(*args, **kwargs):
        return dummy_engine

    with monkeypatch.context() as m:
        m.setattr(AsyncAphrodite, "from_engine_args", dummy_engine_builder)
        server = AsyncEngineRPCServer(None, rpc_path=tmp_socket)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())
    try:
        yield server
    finally:
        server_task.cancel()
        # cancel() only *requests* cancellation. Wait until the loop task has
        # actually stopped before cleanup() tears down what it is using, and
        # so no pending task is left behind when the event loop closes.
        # return_exceptions=True turns the expected CancelledError (or any
        # error the loop had already died with) into a result instead of
        # failing teardown.
        await asyncio.gather(server_task, return_exceptions=True)
        server.cleanup()


@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket, dummy_server):
    """Yield an ``AsyncEngineRPCClient`` connected to the server on
    ``tmp_socket``.

    ``dummy_server`` is requested explicitly, even though the body never uses
    it. This makes pytest always start the server before this fixture waits
    for it, and tear the server down only after the client has been closed,
    regardless of the argument order in the test's signature.

    Yields:
        A client whose server connection has already been confirmed.
    """
    client = AsyncEngineRPCClient(rpc_path=tmp_socket)
    # Sanity check: the server is connected
    await client._wait_for_server_rpc()
    try:
        yield client
    finally:
        client.close()


@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(
    monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
    """A data request the server never answers must end in the client's own
    timeout error instead of hanging."""
    loop = asyncio.get_running_loop()

    with monkeypatch.context() as patcher:
        # Swap the server's config handler for a no-op so that no model
        # config reply is ever sent back.
        patcher.setattr(dummy_server, "get_config", lambda x: None)
        # Shrink the client's data timeout. Presumably the unit is
        # milliseconds, so it fires before the 0.05 s outer deadline below
        # (TODO confirm the unit in the client).
        patcher.setattr(client, "_data_timeout", 10)

        # client.setup() asks the server for its config, so it is expected
        # to fail with the client's TimeoutError rather than block forever.
        setup_task = loop.create_task(client.setup())
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(setup_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(
    monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
    """``abort`` must return normally even if the server never answers."""
    loop = asyncio.get_running_loop()

    with monkeypatch.context() as patcher:
        # Replace the server's abort handler so abort requests hang.
        patcher.setattr(dummy_server, "abort", lambda x: None)
        patcher.setattr(client, "_data_timeout", 10)

        # The client is expected to swallow its own timeout on aborts,
        # trusting the server to finish the abort eventually. The task must
        # therefore complete, without raising, inside the outer deadline.
        abort_task = loop.create_task(client.abort("test request id"))
        await asyncio.wait_for(abort_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
    monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
    """An exception raised by the engine behind the server must surface on
    the client with the same type and message."""
    with monkeypatch.context() as patcher:
        injected_error = RuntimeError("Client test exception")

        def fail_get_model_config():
            raise injected_error

        # Presumably the server queries its engine's model config while
        # serving client.setup(); make that engine call blow up.
        patcher.setattr(dummy_server.engine, "get_model_config",
                        fail_get_model_config)
        patcher.setattr(client, "_data_timeout", 10)

        setup_task = asyncio.get_running_loop().create_task(client.setup())
        # setup() must finish promptly by re-raising the server-side error.
        expected_message = str(injected_error)
        with pytest.raises(RuntimeError, match=expected_message):
            await asyncio.wait_for(setup_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_errors_after_closing(
    monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
    """After ``close()``, calls that need the server raise
    ``RPCClientClosedError``, while no-op style calls still succeed."""
    client.close()

    # A health check on a closed client must fail explicitly.
    with pytest.raises(RPCClientClosedError):
        await client.check_health()

    # So must consuming a generate() stream.
    with pytest.raises(RPCClientClosedError):
        stream = client.generate(None, None, None)
        async for _ in stream:
            pass

    # Aborting and stats logging are tolerated on a closed client.
    await client.abort("test-request-id")
    await client.do_log_stats()
|