# multiproc_worker_utils.py

import asyncio
import multiprocessing
import os
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass
from multiprocessing import Queue
from multiprocessing.connection import wait
from multiprocessing.process import BaseProcess
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
                    TypeVar, Union)

from loguru import logger

import aphrodite.common.envs as envs

T = TypeVar('T')

_TERMINATE = "TERMINATE"  # sentinel

# ANSI color codes
CYAN = '\033[1;36m'
RESET = '\033[0;0m'

JOIN_TIMEOUT_S = 2

# Use dedicated multiprocess context for workers.
# Both spawn and fork work
mp_method = envs.APHRODITE_WORKER_MULTIPROC_METHOD
mp = multiprocessing.get_context(mp_method)
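
# Selection is environment-driven; a hypothetical invocation (the script name
# is an assumption, only the variable name comes from this module):
#
#     APHRODITE_WORKER_MULTIPROC_METHOD=spawn python my_engine.py
#
# "spawn" avoids inheriting CUDA state from the parent; "fork" starts faster.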

@dataclass
class Result(Generic[T]):
    """Result of task dispatched to worker"""
    task_id: uuid.UUID
    value: Optional[T] = None
    exception: Optional[BaseException] = None


class ResultFuture(threading.Event, Generic[T]):
    """Synchronous future for non-async case"""

    def __init__(self):
        super().__init__()
        self.result: Optional[Result[T]] = None

    def set_result(self, result: Result[T]):
        self.result = result
        self.set()

    def get(self) -> T:
        self.wait()
        assert self.result is not None
        if self.result.exception is not None:
            raise self.result.exception
        return self.result.value  # type: ignore[return-value]
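
# Usage sketch (illustrative only): a ResultFuture is what
# ProcessWorkerWrapper.execute_method (below) hands back, so callers can
# block on the worker's reply. "get_some_state" is a hypothetical method
# name, not an API guaranteed by any worker class:
#
#     future = wrapper.execute_method("get_some_state")
#     state = future.get()  # blocks; re-raises the worker's exception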

def _set_future_result(future: Union[ResultFuture, asyncio.Future],
                       result: Result):
    if isinstance(future, ResultFuture):
        future.set_result(result)
        return
    # asyncio futures are not thread-safe, so hand the result over to the
    # future's own event loop rather than setting it from this thread
    loop = future.get_loop()
    if not loop.is_closed():
        if result.exception is not None:
            loop.call_soon_threadsafe(future.set_exception, result.exception)
        else:
            loop.call_soon_threadsafe(future.set_result, result.value)

class ResultHandler(threading.Thread):
    """Handle results from all workers (in background thread)"""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        self.result_queue = mp.Queue()
        self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}

    def run(self):
        # Drain results until the _TERMINATE sentinel arrives
        for result in iter(self.result_queue.get, _TERMINATE):
            future = self.tasks.pop(result.task_id)
            _set_future_result(future, result)
        # Ensure that all waiters will receive an exception
        for task_id, future in self.tasks.items():
            _set_future_result(
                future,
                Result(task_id=task_id,
                       exception=ChildProcessError("worker died")))

    def close(self):
        self.result_queue.put(_TERMINATE)
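
# Lifecycle sketch (illustrative only): one ResultHandler is shared by all
# workers and must be started before results can be awaited.
# `worker_factory` is a hypothetical zero-argument callable:
#
#     result_handler = ResultHandler()
#     wrapper = ProcessWorkerWrapper(result_handler, worker_factory)
#     result_handler.start()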

class WorkerMonitor(threading.Thread):
    """Monitor worker status (in background thread)"""

    def __init__(self, workers: List['ProcessWorkerWrapper'],
                 result_handler: ResultHandler):
        super().__init__(daemon=True)
        self.workers = workers
        self.result_handler = result_handler
        self._close = False

    def run(self) -> None:
        # Blocks until any worker exits
        dead_sentinels = wait([w.process.sentinel for w in self.workers])
        if not self._close:
            self._close = True

            # Kill / cleanup all workers
            for worker in self.workers:
                process = worker.process
                if process.sentinel in dead_sentinels:
                    process.join(JOIN_TIMEOUT_S)
                if process.exitcode is not None and process.exitcode != 0:
                    logger.error(f"Worker {process.name} pid {process.pid} "
                                 f"died, exit code: {process.exitcode}")
            # Cleanup any remaining workers
            if logger:
                logger.info("Killing local Aphrodite worker processes")
            for worker in self.workers:
                worker.kill_worker()
            # Must be done after worker task queues are all closed
            self.result_handler.close()

        for worker in self.workers:
            worker.process.join(JOIN_TIMEOUT_S)

    def close(self):
        if self._close:
            return
        self._close = True
        logger.info("Terminating local Aphrodite worker processes")
        for worker in self.workers:
            worker.terminate_worker()
        # Must be done after worker task queues are all closed
        self.result_handler.close()

class ProcessWorkerWrapper:
    """Local process wrapper for aphrodite.worker.Worker,
    for handling single-node multi-GPU tensor parallel."""

    def __init__(self, result_handler: ResultHandler,
                 worker_factory: Callable[[], Any]) -> None:
        self._task_queue = mp.Queue()
        self.result_queue = result_handler.result_queue
        self.tasks = result_handler.tasks
        self.process: BaseProcess = mp.Process(  # type: ignore[attr-defined]
            target=_run_worker_process,
            name="AphroditeWorkerProcess",
            kwargs=dict(
                worker_factory=worker_factory,
                task_queue=self._task_queue,
                result_queue=self.result_queue,
            ),
            daemon=True)

        self.process.start()

    def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
                      method: str, args, kwargs):
        task_id = uuid.uuid4()
        self.tasks[task_id] = future
        try:
            self._task_queue.put((task_id, method, args, kwargs))
        except SystemExit:
            raise
        except BaseException as e:
            del self.tasks[task_id]
            raise ChildProcessError("worker died") from e

    def execute_method(self, method: str, *args, **kwargs):
        future: ResultFuture = ResultFuture()
        self._enqueue_task(future, method, args, kwargs)
        return future

    async def execute_method_async(self, method: str, *args, **kwargs):
        future = asyncio.get_running_loop().create_future()
        self._enqueue_task(future, method, args, kwargs)
        return await future

    def terminate_worker(self):
        try:
            self._task_queue.put(_TERMINATE)
        except ValueError:
            # Queue is already closed; fall back to killing the process
            self.process.kill()
        self._task_queue.close()

    def kill_worker(self):
        self._task_queue.close()
        self.process.kill()
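
# End-to-end sketch (illustrative; `make_worker`, `tp_size` and the
# "init_device" method name are hypothetical):
#
#     result_handler = ResultHandler()
#     workers = [ProcessWorkerWrapper(result_handler, make_worker)
#                for _ in range(tp_size)]
#     result_handler.start()
#     worker_monitor = WorkerMonitor(workers, result_handler)
#     worker_monitor.start()
#
#     for f in [w.execute_method("init_device") for w in workers]:
#         f.get()  # block until every worker has run the method
#
#     worker_monitor.close()  # terminates workers and the result handler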

def _run_worker_process(
    worker_factory: Callable[[], Any],
    task_queue: Queue,
    result_queue: Queue,
) -> None:
    """Worker process event loop"""

    # Add process-specific prefix to stdout and stderr
    process_name = mp.current_process().name
    pid = os.getpid()
    _add_prefix(sys.stdout, process_name, pid)
    _add_prefix(sys.stderr, process_name, pid)

    # Initialize worker
    worker = worker_factory()
    del worker_factory

    # Accept tasks from the engine in task_queue
    # and return task output in result_queue
    logger.info("Worker ready; awaiting tasks")
    try:
        for items in iter(task_queue.get, _TERMINATE):
            output = None
            exception = None
            task_id, method, args, kwargs = items
            try:
                executor = getattr(worker, method)
                output = executor(*args, **kwargs)
            except SystemExit:
                raise
            except KeyboardInterrupt:
                break
            except BaseException as e:
                tb = traceback.format_exc()
                logger.error(f"Exception in worker {process_name} while "
                             f"processing method {method}: {e}, {tb}")
                exception = e
            result_queue.put(
                Result(task_id=task_id, value=output, exception=exception))
    except KeyboardInterrupt:
        pass
    except Exception:
        logger.exception("Worker failed")

    logger.info("Worker exiting")
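
# Wire format (as consumed by the loop above): each task is a
# (task_id, method, args, kwargs) tuple answered with a Result on
# result_queue; the string "execute_model" below is a hypothetical
# worker method, purely for illustration:
#
#     task_queue.put((uuid.uuid4(), "execute_model", (), {}))
#     task_queue.put(_TERMINATE)  # shuts the loop down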

def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Prepend each output line with process-specific prefix"""

    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write

    def write_with_prefix(s: str):
        if not s:
            return
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        while (next_idx := s.find('\n', idx)) != -1:
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                file.start_new_line = True  # type: ignore[attr-defined]
                return
            file_write(prefix)
            idx = next_idx
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]
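
# Behavior sketch: after _add_prefix(sys.stdout, "Worker-1", 1234)
# (hypothetical name and pid), print("hello\nworld") renders roughly as
# (color codes omitted):
#
#     (Worker-1 pid=1234) hello
#     (Worker-1 pid=1234) world
#
# because the wrapped write() re-inserts the prefix after every newline.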