1
0

test_load.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
"""Test that the MQAphroditeEngine is able to handle 10k concurrent requests."""
  2. import asyncio
  3. import tempfile
  4. import uuid
  5. import pytest
  6. from aphrodite.engine.args_tools import AsyncEngineArgs
  7. from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine, generate
  8. MODEL = "google/gemma-1.1-2b-it"
  9. NUM_EXPECTED_TOKENS = 10
  10. NUM_REQUESTS = 10000
  11. # Scenarios to test for num generated token.
  12. ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
  13. @pytest.fixture(scope="function")
  14. def tmp_socket():
  15. with tempfile.TemporaryDirectory() as td:
  16. yield f"ipc://{td}/{uuid.uuid4()}"
  17. @pytest.mark.asyncio
  18. async def test_load(tmp_socket):
  19. with RemoteMQAphroditeEngine(
  20. engine_args=ENGINE_ARGS,
  21. ipc_path=tmp_socket) as engine:
  22. client = await engine.make_client()
  23. request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
  24. # Create concurrent requests.
  25. tasks = []
  26. for request_id in request_ids:
  27. tasks.append(
  28. asyncio.create_task(
  29. generate(client, request_id, NUM_EXPECTED_TOKENS)))
  30. # Confirm that we got all the EXPECTED tokens from the requests.
  31. failed_request_id = None
  32. tokens = None
  33. for task in tasks:
  34. num_generated_tokens, request_id = await task
  35. if (num_generated_tokens != NUM_EXPECTED_TOKENS
  36. and failed_request_id is None):
  37. failed_request_id = request_id
  38. tokens = num_generated_tokens
  39. assert failed_request_id is None, (
  40. f"{failed_request_id} generated {tokens} but "
  41. f"expected {NUM_EXPECTED_TOKENS}")
  42. # Shutdown.
  43. client.close()