test_async_aphrodite.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import asyncio
  2. from dataclasses import dataclass
  3. import pytest
  4. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  5. @dataclass
  6. class RequestOutput:
  7. request_id: int
  8. finished: bool = False
  9. class MockEngine:
  10. def __init__(self):
  11. self.step_calls = 0
  12. self.add_request_calls = 0
  13. self.abort_request_calls = 0
  14. self.request_id = None
  15. async def step_async(self):
  16. self.step_calls += 1
  17. return [RequestOutput(
  18. request_id=self.request_id)] if self.request_id else []
  19. async def encode_request_async(self, *args, **kwargs):
  20. pass
  21. def generate(self, request_id):
  22. self.request_id = request_id
  23. def stop_generating(self):
  24. self.request_id = None
  25. def add_request(self, **kwargs):
  26. del kwargs # Unused
  27. self.add_request_calls += 1
  28. async def add_request_async(self, **kwargs):
  29. self.add_request_calls += 1
  30. return
  31. def abort_request(self, request_id):
  32. del request_id # Unused
  33. self.abort_request_calls += 1
  34. def has_unfinished_requests(self):
  35. return self.request_id is not None
  36. class MockAsyncAphrodite(AsyncAphrodite):
  37. def _init_engine(self, *args, **kwargs):
  38. return MockEngine()
@pytest.mark.asyncio
async def test_new_requests_event():
    """Check that the background loop steps the engine only when there is work.

    The loop must stay idle until a request is added, step once for a newly
    added request whose step yields no output, keep stepping while the mock
    engine is producing output, and stop stepping once output stops.

    NOTE(review): this relies on short wall-clock sleeps and exact
    scheduling order, so it may be flaky on a heavily loaded machine.
    NOTE(review): the background loop started here is never stopped.
    """
    engine = MockAsyncAphrodite(worker_use_ray=False, engine_use_ray=False)
    # assumes AsyncAphrodite keeps the result of _init_engine (a MockEngine)
    # as ``engine.engine`` — that is what the assertions below inspect.
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # No requests yet: the loop must not have stepped the engine.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    # The new request reaches the engine and causes exactly one step.
    # Nothing is "generating" in the mock, so that step returns [] and the
    # loop presumably goes idle again.
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    # Make the mock return an output for request "2" on every step.
    engine.engine.generate("2")
    # Yield to the event loop twice so the background loop gets to run.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    # While steps keep producing output, the loop keeps stepping.
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    # Let any in-flight step finish before taking the baseline count.
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    # With no output being produced, no further steps happen.
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    # Another new request wakes the loop for exactly one step...
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    # ...after which the loop stays idle.
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1