test_async_aphrodite.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import asyncio
  2. from dataclasses import dataclass
  3. import pytest
  4. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  5. @dataclass
  6. class RequestOutput:
  7. request_id: int
  8. finished: bool = False
  9. class MockEngine:
  10. def __init__(self):
  11. self.step_calls = 0
  12. self.add_request_calls = 0
  13. self.abort_request_calls = 0
  14. self.request_id = None
  15. async def step_async(self):
  16. self.step_calls += 1
  17. return [RequestOutput(
  18. request_id=self.request_id)] if self.request_id else []
  19. def generate(self, request_id):
  20. self.request_id = request_id
  21. def stop_generating(self):
  22. self.request_id = None
  23. def add_request(self, **kwargs):
  24. del kwargs # Unused
  25. self.add_request_calls += 1
  26. def abort_request(self, request_id):
  27. del request_id # Unused
  28. self.abort_request_calls += 1
  29. class MockAsyncAphrodite(AsyncAphrodite):
  30. def _init_engine(self, *args, **kwargs):
  31. return MockEngine()
  32. @pytest.mark.asyncio
  33. async def test_new_requests_event():
  34. engine = MockAsyncAphrodite(worker_use_ray=False, engine_use_ray=False)
  35. engine.start_background_loop()
  36. await asyncio.sleep(0.01)
  37. assert engine.engine.step_calls == 0
  38. await engine.add_request("1", "", None)
  39. await asyncio.sleep(0.01)
  40. assert engine.engine.add_request_calls == 1
  41. assert engine.engine.step_calls == 1
  42. await engine.add_request("2", "", None)
  43. engine.engine.generate("2")
  44. await asyncio.sleep(0)
  45. assert engine.engine.add_request_calls == 2
  46. assert engine.engine.step_calls == 2
  47. await asyncio.sleep(0)
  48. assert engine.engine.step_calls == 3
  49. engine.engine.stop_generating()
  50. await asyncio.sleep(0)
  51. assert engine.engine.step_calls == 4
  52. await asyncio.sleep(0)
  53. assert engine.engine.step_calls == 4
  54. await engine.add_request("3", "", None)
  55. await asyncio.sleep(0.01)
  56. assert engine.engine.add_request_calls == 3
  57. assert engine.engine.step_calls == 5
  58. await asyncio.sleep(0.01)
  59. assert engine.engine.add_request_calls == 3
  60. assert engine.engine.step_calls == 5