test_abort.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. """Test that aborting is handled properly."""
  2. import asyncio
  3. import tempfile
  4. import uuid
  5. import pytest
  6. from aphrodite.engine.args_tools import AsyncEngineArgs
  7. from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine, generate
  8. MODEL = "google/gemma-1.1-2b-it"
  9. ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
  10. RAISED_ERROR = KeyError
  11. RAISED_VALUE = "foo"
  12. EXPECTED_TOKENS = 250
  13. @pytest.fixture(scope="function")
  14. def tmp_socket():
  15. with tempfile.TemporaryDirectory() as td:
  16. yield f"ipc://{td}/{uuid.uuid4()}"
  17. @pytest.mark.asyncio
  18. async def test_abort(tmp_socket):
  19. with RemoteMQAphroditeEngine(
  20. engine_args=ENGINE_ARGS,
  21. ipc_path=tmp_socket) as engine:
  22. client = await engine.make_client()
  23. request_id_to_be_aborted = "request-aborted"
  24. request_ids_a = [f"request-a-{idx}" for idx in range(10)]
  25. request_ids_b = [f"request-b-{idx}" for idx in range(10)]
  26. # Requests started before one to be aborted.
  27. tasks = []
  28. for request_id in request_ids_a:
  29. tasks.append(
  30. asyncio.create_task(
  31. generate(client, request_id, EXPECTED_TOKENS)))
  32. # Aborted.
  33. task_aborted = asyncio.create_task(
  34. generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
  35. # Requests started after one to be aborted.
  36. for request_id in request_ids_b:
  37. tasks.append(
  38. asyncio.create_task(
  39. generate(client, request_id, EXPECTED_TOKENS)))
  40. # Actually abort.
  41. await asyncio.sleep(0.5)
  42. await client.abort(request_id_to_be_aborted)
  43. # Confirm that we got all the EXPECTED tokens from the requests.
  44. for task in tasks:
  45. count, request_id = await task
  46. assert count == EXPECTED_TOKENS, (
  47. f"{request_id} generated only {count} tokens")
  48. # Cancel task (this will hang indefinitely if not).
  49. task_aborted.cancel()
  50. # Shutdown.
  51. client.close()