1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- """Test that aborting is handled properly."""
- import asyncio
- import tempfile
- import uuid
- import pytest
- from aphrodite.engine.args_tools import AsyncEngineArgs
- from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine, generate
# Model passed to the engine under test.
MODEL = "google/gemma-1.1-2b-it"
# Engine configuration shared by the tests in this module.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)

# NOTE(review): RAISED_ERROR / RAISED_VALUE are not referenced in the visible
# part of this module; kept because other code may import them.
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"

# Token count passed to ``generate`` and expected back from every request
# that is not aborted.
EXPECTED_TOKENS = 250
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique ``ipc://`` path located inside a temporary directory.

    The directory is created for the test and removed again when the
    ``TemporaryDirectory`` context exits after the test has finished.
    """
    with tempfile.TemporaryDirectory() as socket_dir:
        socket_name = uuid.uuid4()
        ipc_path = f"ipc://{socket_dir}/{socket_name}"
        yield ipc_path
@pytest.mark.asyncio
async def test_abort(tmp_socket):
    """Check that aborting one request does not disturb the others.

    Ten requests are started, then the request that will be aborted, then
    ten more. After a short delay the middle request is aborted through the
    client, and every other request must still report ``EXPECTED_TOKENS``
    tokens.

    Args:
        tmp_socket: IPC path supplied by the ``tmp_socket`` fixture.
    """
    with RemoteMQAphroditeEngine(
            engine_args=ENGINE_ARGS,
            ipc_path=tmp_socket) as engine:

        client = await engine.make_client()

        request_id_to_be_aborted = "request-aborted"
        request_ids_a = [f"request-a-{idx}" for idx in range(10)]
        request_ids_b = [f"request-b-{idx}" for idx in range(10)]

        tasks = []
        task_aborted = None
        try:
            # Requests started before one to be aborted.
            tasks.extend(
                asyncio.create_task(
                    generate(client, request_id, EXPECTED_TOKENS))
                for request_id in request_ids_a)

            # Aborted.
            task_aborted = asyncio.create_task(
                generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))

            # Requests started after one to be aborted.
            tasks.extend(
                asyncio.create_task(
                    generate(client, request_id, EXPECTED_TOKENS))
                for request_id in request_ids_b)

            # Actually abort, after giving the requests time to get going.
            await asyncio.sleep(0.5)
            await client.abort(request_id_to_be_aborted)

            # Confirm that we got all the EXPECTED tokens from the requests.
            for task in tasks:
                count, request_id = await task
                assert count == EXPECTED_TOKENS, (
                    f"{request_id} generated only {count} tokens")
        finally:
            # Clean up even when an assertion or an awaited task failed, so
            # no task is left pending and the client is always closed.
            # The aborted request's task must be cancelled explicitly (it
            # will hang indefinitely if not). ``cancel()`` is a no-op for
            # tasks that have already finished.
            if task_aborted is not None:
                task_aborted.cancel()
            for task in tasks:
                task.cancel()

            # Shutdown.
            client.close()
|