123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- import asyncio
- import os
- from dataclasses import dataclass
- import pytest
- import torch
- from aphrodite import SamplingParams
- from aphrodite.common.config import ParallelConfig
- from aphrodite.engine.async_aphrodite import AsyncAphrodite, AsyncEngineArgs
- from ..utils import wait_for_gpu_memory_to_clear
@dataclass
class RequestOutput:
    """Minimal output record produced by the mock engine's step.

    Only carries the two fields declared below; equality and ``repr`` come
    from ``@dataclass``.
    """

    # Identifier of the request this output belongs to.
    # NOTE(review): annotated as ``int``, but the tests in this file feed
    # string ids (e.g. "2") through ``MockEngine.generate``; dataclasses do
    # not enforce annotations at runtime.
    request_id: int
    # Completion flag; a freshly built output is "not finished".
    finished: bool = False
class MockEngine:
    """Fake engine that counts calls instead of running a model.

    The single ``request_id`` attribute models "the request currently being
    generated": ``generate`` sets it, ``stop_generating`` clears it, and
    ``None`` means the engine is idle.  Every query about pending work uses
    the same ``is not None`` test so that falsy ids (``0``, ``""``) are
    treated as real requests everywhere.
    """

    def __init__(self):
        # Call counters inspected by the tests.
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request being "generated"; None when idle.
        self.request_id = None
        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig(1, 1, False)

    async def step_async(self, virtual_engine):
        """Count one step and return the outputs for the active request.

        Returns a one-element list holding a ``RequestOutput`` for the
        active request, or an empty list when no request is active.
        """
        # PP size is 1, ignore virtual engine
        del virtual_engine
        self.step_calls += 1
        # Compare against None (not truthiness) so this agrees with
        # has_unfinished_requests() for ids such as 0 or "".
        if self.request_id is None:
            return []
        return [RequestOutput(request_id=self.request_id)]

    async def process_model_inputs_async(self, *args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass

    async def stop_remote_worker_execution_loop_async(self):
        """No-op."""
        pass

    def generate(self, request_id):
        """Mark ``request_id`` as the request currently being generated."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request, returning the engine to idle."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count a synchronous add-request call and print the new total."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Count an asynchronous add-request call."""
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort call; the id itself is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return True while a request is active (``request_id`` set)."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Same answer as ``has_unfinished_requests``; the index is unused."""
        del virtual_engine  # Unused
        return self.request_id is not None
class MockAsyncAphrodite(AsyncAphrodite):
    """``AsyncAphrodite`` subclass whose engine is a ``MockEngine``."""

    def _init_engine(self, *args, **kwargs):
        """Return a fresh ``MockEngine``, ignoring all arguments."""
        # Whatever the base class passes for engine construction is
        # irrelevant to the mock, so it is dropped here.
        del args, kwargs
        mock_engine = MockEngine()
        return mock_engine
@pytest.mark.asyncio
async def test_new_requests_event():
    """Check when the background loop steps the mock engine.

    The assertions cover, in order: no steps before any request is added;
    one step after a request is added; continued stepping while the mock
    reports an active request; no further steps once it stops; and exactly
    one more step after another request is added.  Finally, an engine built
    with ``engine_use_ray=True`` must still expose its model config,
    tokenizer and decoding config.
    """
    engine = MockAsyncAphrodite(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    # Yield to the event loop three times so the background task can run.
    for _ in range(3):
        await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3

    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    # A second wait must not trigger any additional add/step calls.
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # Allow deprecated engine_use_ray to not raise exception
    env_key = "APHRODITE_ALLOW_ENGINE_USE_RAY"
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = "1"
    try:
        engine = MockAsyncAphrodite(worker_use_ray=True, engine_use_ray=True)
        assert engine.get_model_config() is not None
        assert engine.get_tokenizer() is not None
        assert engine.get_decoding_config() is not None
    finally:
        # Always undo the environment change, even when an assertion above
        # fails, so the flag cannot leak into other tests.
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value
def test_asyncio_run():
    """Drive two concurrent generations through a plain ``asyncio.run``.

    Builds an engine for ``facebook/opt-125m`` and gathers two ``generate``
    streams, keeping the last output of each.  Both must produce a result.
    """
    # Wait until every visible CUDA device is below the 2 GiB threshold
    # (or 60 s elapse) before building the engine.
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    engine = AsyncAphrodite.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    async def run(prompt: str):
        """Consume one generation stream and return its last output."""
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
        )

        # Start from None so an empty stream yields a clear assertion
        # failure below instead of an UnboundLocalError here.
        final_output = None
        async for output in engine.generate(prompt,
                                            sampling_params,
                                            request_id=prompt):
            final_output = output
        return final_output

    async def generate():
        """Run both prompts concurrently and collect their final outputs."""
        return await asyncio.gather(
            run("test0"),
            run("test1"),
        )

    results = asyncio.run(generate())
    assert len(results) == 2
    assert all(result is not None for result in results)
|