test_zmq_client.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import asyncio
  2. import tempfile
  3. import unittest
  4. import unittest.mock
  5. import uuid
  6. import pytest
  7. import pytest_asyncio
  8. from aphrodite.endpoints.openai.rpc.client import (AsyncEngineRPCClient,
  9. RPCClientClosedError)
  10. from aphrodite.endpoints.openai.rpc.server import AsyncEngineRPCServer
  11. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  12. @pytest.fixture(scope="function")
  13. def tmp_socket():
  14. with tempfile.TemporaryDirectory() as td:
  15. yield f"ipc://{td}/{uuid.uuid4()}"
  16. @pytest_asyncio.fixture(scope="function")
  17. async def dummy_server(tmp_socket, monkeypatch):
  18. dummy_engine = unittest.mock.AsyncMock()
  19. def dummy_engine_builder(*args, **kwargs):
  20. return dummy_engine
  21. with monkeypatch.context() as m:
  22. m.setattr(AsyncAphrodite, "from_engine_args", dummy_engine_builder)
  23. server = AsyncEngineRPCServer(None, rpc_path=tmp_socket)
  24. loop = asyncio.get_running_loop()
  25. server_task = loop.create_task(server.run_server_loop())
  26. try:
  27. yield server
  28. finally:
  29. server_task.cancel()
  30. server.cleanup()
  31. @pytest_asyncio.fixture(scope="function")
  32. async def client(tmp_socket):
  33. client = AsyncEngineRPCClient(rpc_path=tmp_socket)
  34. # Sanity check: the server is connected
  35. await client._wait_for_server_rpc()
  36. try:
  37. yield client
  38. finally:
  39. client.close()
  40. @pytest.mark.asyncio
  41. async def test_client_data_methods_use_timeouts(
  42. monkeypatch, dummy_server, client: AsyncEngineRPCClient
  43. ):
  44. with monkeypatch.context() as m:
  45. # Make the server _not_ reply with a model config
  46. m.setattr(dummy_server, "get_config", lambda x: None)
  47. m.setattr(client, "_data_timeout", 10)
  48. # And ensure the task completes anyway
  49. # (client.setup() invokes server.get_config())
  50. client_task = asyncio.get_running_loop().create_task(client.setup())
  51. with pytest.raises(TimeoutError, match="Server didn't reply within"):
  52. await asyncio.wait_for(client_task, timeout=0.05)
  53. @pytest.mark.asyncio
  54. async def test_client_aborts_use_timeouts(
  55. monkeypatch, dummy_server, client: AsyncEngineRPCClient
  56. ):
  57. with monkeypatch.context() as m:
  58. # Hang all abort requests
  59. m.setattr(dummy_server, "abort", lambda x: None)
  60. m.setattr(client, "_data_timeout", 10)
  61. # The client should suppress timeouts on `abort`s
  62. # and return normally, assuming the server will eventually
  63. # abort the request.
  64. client_task = asyncio.get_running_loop().create_task(
  65. client.abort("test request id"))
  66. await asyncio.wait_for(client_task, timeout=0.05)
  67. @pytest.mark.asyncio
  68. async def test_client_data_methods_reraise_exceptions(
  69. monkeypatch, dummy_server, client: AsyncEngineRPCClient
  70. ):
  71. with monkeypatch.context() as m:
  72. # Make the server raise some random exception
  73. exception = RuntimeError("Client test exception")
  74. def raiser():
  75. raise exception
  76. m.setattr(dummy_server.engine, "get_model_config", raiser)
  77. m.setattr(client, "_data_timeout", 10)
  78. client_task = asyncio.get_running_loop().create_task(client.setup())
  79. # And ensure the task completes, raising the exception
  80. with pytest.raises(RuntimeError, match=str(exception)):
  81. await asyncio.wait_for(client_task, timeout=0.05)
  82. @pytest.mark.asyncio
  83. async def test_client_errors_after_closing(
  84. monkeypatch, dummy_server, client: AsyncEngineRPCClient
  85. ):
  86. client.close()
  87. # Healthchecks and generate requests will fail with explicit errors
  88. with pytest.raises(RPCClientClosedError):
  89. await client.check_health()
  90. with pytest.raises(RPCClientClosedError):
  91. async for _ in client.generate(None, None, None):
  92. pass
  93. # But no-ops like aborting will pass
  94. await client.abort("test-request-id")
  95. await client.do_log_stats()