1
0

test_multiproc_workers.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import asyncio
  2. from concurrent.futures import ThreadPoolExecutor
  3. from functools import partial
  4. from time import sleep
  5. from typing import Any, List, Tuple
  6. import pytest
  7. from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
  8. ResultHandler,
  9. WorkerMonitor)
  10. class DummyWorker:
  11. """Dummy version of aphrodite.task_handler.worker.Worker"""
  12. def __init__(self, rank: int):
  13. self.rank = rank
  14. def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
  15. sleep(0.05)
  16. if isinstance(worker_input, Exception):
  17. # simulate error case
  18. raise worker_input
  19. return self.rank, input
  20. def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
  21. result_handler = ResultHandler()
  22. workers = [
  23. ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
  24. for rank in range(8)
  25. ]
  26. worker_monitor = WorkerMonitor(workers, result_handler)
  27. assert not worker_monitor.is_alive()
  28. result_handler.start()
  29. worker_monitor.start()
  30. assert worker_monitor.is_alive()
  31. return workers, worker_monitor
  32. def test_local_workers() -> None:
  33. """Test workers with sync task submission"""
  34. workers, worker_monitor = _start_workers()
  35. def execute_workers(worker_input: str) -> None:
  36. worker_outputs = [
  37. worker.execute_method("worker_method", worker_input)
  38. for worker in workers
  39. ]
  40. for rank, output in enumerate(worker_outputs):
  41. assert output.get() == (rank, input)
  42. executor = ThreadPoolExecutor(max_workers=4)
  43. # Test concurrent submission from different threads
  44. futures = [
  45. executor.submit(partial(execute_workers, f"thread {thread_num}"))
  46. for thread_num in range(4)
  47. ]
  48. for future in futures:
  49. future.result()
  50. # Test error case
  51. exception = ValueError("fake error")
  52. result = workers[0].execute_method("worker_method", exception)
  53. try:
  54. result.get()
  55. pytest.fail("task should have failed")
  56. except Exception as e:
  57. assert isinstance(e, ValueError)
  58. assert str(e) == "fake error"
  59. # Test cleanup when a worker fails
  60. assert worker_monitor.is_alive()
  61. workers[3].process.kill()
  62. # Other workers should get shut down here
  63. worker_monitor.join(2)
  64. # Ensure everything is stopped
  65. assert not worker_monitor.is_alive()
  66. assert all(not worker.process.is_alive() for worker in workers)
  67. # Further attempts to submit tasks should fail
  68. try:
  69. _result = workers[0].execute_method("worker_method", "test")
  70. pytest.fail("task should fail once workers have been shut down")
  71. except Exception as e:
  72. assert isinstance(e, ChildProcessError)
  73. def test_local_workers_clean_shutdown() -> None:
  74. """Test clean shutdown"""
  75. workers, worker_monitor = _start_workers()
  76. assert worker_monitor.is_alive()
  77. assert all(worker.process.is_alive() for worker in workers)
  78. # Clean shutdown
  79. worker_monitor.close()
  80. worker_monitor.join(5)
  81. # Ensure everything is stopped
  82. assert not worker_monitor.is_alive()
  83. assert all(not worker.process.is_alive() for worker in workers)
  84. # Further attempts to submit tasks should fail
  85. try:
  86. _result = workers[0].execute_method("worker_method", "test")
  87. pytest.fail("task should fail once workers have been shut down")
  88. except Exception as e:
  89. assert isinstance(e, ChildProcessError)
  90. @pytest.mark.asyncio
  91. async def test_local_workers_async() -> None:
  92. """Test local workers with async task submission"""
  93. workers, worker_monitor = _start_workers()
  94. async def execute_workers(worker_input: str) -> None:
  95. worker_coros = [
  96. worker.execute_method_async("worker_method", worker_input)
  97. for worker in workers
  98. ]
  99. results = await asyncio.gather(*worker_coros)
  100. for rank, result in enumerate(results):
  101. assert result == (rank, input)
  102. tasks = [
  103. asyncio.create_task(execute_workers(f"task {task_num}"))
  104. for task_num in range(4)
  105. ]
  106. for task in tasks:
  107. await task
  108. # Test error case
  109. exception = ValueError("fake error")
  110. try:
  111. _result = await workers[0].execute_method_async(
  112. "worker_method", exception)
  113. pytest.fail("task should have failed")
  114. except Exception as e:
  115. assert isinstance(e, ValueError)
  116. assert str(e) == "fake error"
  117. # Test cleanup when a worker fails
  118. assert worker_monitor.is_alive()
  119. workers[3].process.kill()
  120. # Other workers should get shut down here
  121. worker_monitor.join(2)
  122. # Ensure everything is stopped
  123. assert not worker_monitor.is_alive()
  124. assert all(not worker.process.is_alive() for worker in workers)
  125. # Further attempts to submit tasks should fail
  126. try:
  127. _result = await workers[0].execute_method_async(
  128. "worker_method", "test")
  129. pytest.fail("task should fail once workers have been shut down")
  130. except Exception as e:
  131. assert isinstance(e, ChildProcessError)