# test_request_tracker.py
import pytest

from aphrodite.common.outputs import RequestOutput
from aphrodite.engine.async_aphrodite import RequestTracker


  4. @pytest.mark.asyncio
  5. async def test_request_tracker():
  6. tracker = RequestTracker()
  7. stream_1 = tracker.add_request("1")
  8. assert tracker.new_requests_event.is_set()
  9. await tracker.wait_for_new_requests()
  10. new, finished = tracker.get_new_and_finished_requests()
  11. assert not tracker.new_requests_event.is_set()
  12. assert len(new) == 1
  13. assert new[0]["request_id"] == "1"
  14. assert not finished
  15. assert not stream_1.finished
  16. stream_2 = tracker.add_request("2")
  17. stream_3 = tracker.add_request("3")
  18. assert tracker.new_requests_event.is_set()
  19. await tracker.wait_for_new_requests()
  20. new, finished = tracker.get_new_and_finished_requests()
  21. assert not tracker.new_requests_event.is_set()
  22. assert len(new) == 2
  23. assert new[0]["request_id"] == "2"
  24. assert new[1]["request_id"] == "3"
  25. assert not finished
  26. assert not stream_2.finished
  27. assert not stream_3.finished
  28. # request_ids must be unique
  29. with pytest.raises(KeyError):
  30. tracker.add_request("1")
  31. assert not tracker.new_requests_event.is_set()
  32. tracker.abort_request("1")
  33. new, finished = tracker.get_new_and_finished_requests()
  34. assert len(finished) == 1
  35. assert "1" in finished
  36. assert not new
  37. assert stream_1.finished
  38. stream_4 = tracker.add_request("4")
  39. tracker.abort_request("4")
  40. assert tracker.new_requests_event.is_set()
  41. await tracker.wait_for_new_requests()
  42. new, finished = tracker.get_new_and_finished_requests()
  43. assert len(finished) == 1
  44. assert "4" in finished
  45. assert not new
  46. assert stream_4.finished
  47. stream_5 = tracker.add_request("5")
  48. assert tracker.new_requests_event.is_set()
  49. tracker.process_request_output(
  50. RequestOutput("2", "output", [], [], [], finished=True))
  51. await tracker.wait_for_new_requests()
  52. new, finished = tracker.get_new_and_finished_requests()
  53. assert not tracker.new_requests_event.is_set()
  54. assert len(finished) == 1
  55. assert "2" in finished
  56. assert len(new) == 1
  57. assert new[0]["request_id"] == "5"
  58. assert stream_2.finished
  59. assert not stream_5.finished