multiproc_worker_utils.py

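"""Utilities for running Aphrodite workers in background processes.

Tasks are dispatched to worker processes over per-worker task queues; results
come back on a shared result queue and are matched to their callers by task
id.
"""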

import asyncio
import multiprocessing
import os
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass
from multiprocessing import Queue
from multiprocessing.connection import wait
from multiprocessing.process import BaseProcess
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
                    TypeVar, Union)

from loguru import logger

import aphrodite.common.envs as envs

T = TypeVar('T')

_TERMINATE = "TERMINATE"  # sentinel

# ANSI color codes
CYAN = '\033[1;36m'
RESET = '\033[0;0m'

JOIN_TIMEOUT_S = 2

# Use a dedicated multiprocessing context for workers; both the "spawn" and
# "fork" start methods work.
mp_method = envs.APHRODITE_WORKER_MULTIPROC_METHOD
mp = multiprocessing.get_context(mp_method)
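# All queues and processes below are created from this context, so the
# configured start method applies consistently to workers and their queues.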


@dataclass
class Result(Generic[T]):
    """Result of task dispatched to worker"""

    task_id: uuid.UUID
    value: Optional[T] = None
    exception: Optional[BaseException] = None


class ResultFuture(threading.Event, Generic[T]):
    """Synchronous future for non-async case"""

    def __init__(self):
        super().__init__()
        self.result: Optional[Result[T]] = None

    def set_result(self, result: Result[T]):
        self.result = result
        self.set()

    def get(self) -> T:
        self.wait()
        assert self.result is not None
        if self.result.exception is not None:
            raise self.result.exception
        return self.result.value  # type: ignore[return-value]
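

# Completes either kind of future: a ResultFuture can be set directly from
# this thread, while an asyncio.Future must be completed on the event loop
# that owns it, hence call_soon_threadsafe().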
def _set_future_result(future: Union[ResultFuture, asyncio.Future],
                       result: Result):
    if isinstance(future, ResultFuture):
        future.set_result(result)
        return
    loop = future.get_loop()
    if not loop.is_closed():
        if result.exception is not None:
            loop.call_soon_threadsafe(future.set_exception, result.exception)
        else:
            loop.call_soon_threadsafe(future.set_result, result.value)


class ResultHandler(threading.Thread):
    """Handle results from all workers (in background thread)"""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        self.result_queue = mp.Queue()
        self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}

    def run(self):
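        # Drain results until the _TERMINATE sentinel is received, completing
        # the future registered under each task id.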
        for result in iter(self.result_queue.get, _TERMINATE):
            future = self.tasks.pop(result.task_id)
            _set_future_result(future, result)
        # Ensure that all waiters will receive an exception
        for task_id, future in self.tasks.items():
            _set_future_result(
                future,
                Result(task_id=task_id,
                       exception=ChildProcessError("worker died")))

    def close(self):
        self.result_queue.put(_TERMINATE)


class WorkerMonitor(threading.Thread):
    """Monitor worker status (in background thread)"""

    def __init__(self, workers: List['ProcessWorkerWrapper'],
                 result_handler: ResultHandler):
        super().__init__(daemon=True)
        self.workers = workers
        self.result_handler = result_handler
        self._close = False

    def run(self) -> None:
        # Blocks until any worker exits
        dead_sentinels = wait([w.process.sentinel for w in self.workers])
        if not self._close:
            self._close = True

            # Kill / cleanup all workers
            for worker in self.workers:
                process = worker.process
                if process.sentinel in dead_sentinels:
                    process.join(JOIN_TIMEOUT_S)
                if process.exitcode is not None and process.exitcode != 0:
                    logger.error(f"Worker {process.name} pid {process.pid} "
                                 f"died, exit code: {process.exitcode}")

            # Cleanup any remaining workers
            logger.info("Killing local Aphrodite worker processes")
            for worker in self.workers:
                worker.kill_worker()
            # Must be done after worker task queues are all closed
            self.result_handler.close()

        for worker in self.workers:
            worker.process.join(JOIN_TIMEOUT_S)

    def close(self):
        if self._close:
            return
        self._close = True
        logger.info("Terminating local Aphrodite worker processes")
        for worker in self.workers:
            worker.terminate_worker()
        # Must be done after worker task queues are all closed
        self.result_handler.close()


class ProcessWorkerWrapper:
    """Local process wrapper for aphrodite.task_handler.Worker,
    for handling single-node multi-GPU tensor parallel."""

    def __init__(self, result_handler: ResultHandler,
                 worker_factory: Callable[[], Any]) -> None:
        self._task_queue = mp.Queue()
        self.result_queue = result_handler.result_queue
        self.tasks = result_handler.tasks
        self.process: BaseProcess = mp.Process(  # type: ignore[attr-defined]
            target=_run_worker_process,
            name="AphroditeWorkerProcess",
            kwargs=dict(
                worker_factory=worker_factory,
                task_queue=self._task_queue,
                result_queue=self.result_queue,
            ),
            daemon=True)

        self.process.start()
    def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
                      method: str, args, kwargs):
        task_id = uuid.uuid4()
        self.tasks[task_id] = future
        try:
            self._task_queue.put((task_id, method, args, kwargs))
        except BaseException as e:
            del self.tasks[task_id]
            raise ChildProcessError("worker died") from e

    def execute_method(self, method: str, *args, **kwargs):
        future: ResultFuture = ResultFuture()
        self._enqueue_task(future, method, args, kwargs)
        return future

    async def execute_method_async(self, method: str, *args, **kwargs):
        future = asyncio.get_running_loop().create_future()
        self._enqueue_task(future, method, args, kwargs)
        return await future

    def terminate_worker(self):
        try:
            self._task_queue.put(_TERMINATE)
        except ValueError:
            self.process.kill()
        self._task_queue.close()

    def kill_worker(self):
        self._task_queue.close()
        self.process.kill()
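

# Illustrative wiring sketch (not part of this module): `make_worker`,
# `tp_size`, and "init_device" are hypothetical stand-ins, but the classes
# above are intended to be combined roughly like this:
#
#     result_handler = ResultHandler()
#     workers = [ProcessWorkerWrapper(result_handler, make_worker)
#                for _ in range(tp_size)]
#     worker_monitor = WorkerMonitor(workers, result_handler)
#     result_handler.start()
#     worker_monitor.start()
#     futures = [w.execute_method("init_device") for w in workers]
#     values = [f.get() for f in futures]  # ResultFuture.get() blocks
#     ...
#     worker_monitor.close()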


def _run_worker_process(
    worker_factory: Callable[[], Any],
    task_queue: Queue,
    result_queue: Queue,
) -> None:
    """Worker process event loop"""

    # Add process-specific prefix to stdout and stderr
    process_name = mp.current_process().name
    pid = os.getpid()
    _add_prefix(sys.stdout, process_name, pid)
    _add_prefix(sys.stderr, process_name, pid)

    # Initialize worker
    worker = worker_factory()
    del worker_factory

    # Accept tasks from the engine in task_queue
    # and return task output in result_queue
    logger.info("Worker ready; awaiting tasks")

    try:
        for items in iter(task_queue.get, _TERMINATE):
            output = None
            exception = None
            task_id, method, args, kwargs = items
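            # Capture any exception (BaseException included) and ship it back
            # to the caller rather than letting one failed task kill the loop.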
            try:
                executor = getattr(worker, method)
                output = executor(*args, **kwargs)
            except BaseException as e:
                tb = traceback.format_exc()
                logger.error(f"Exception in worker {process_name} while "
                             f"processing method {method}: {e}, {tb}")
                exception = e
            result_queue.put(
                Result(task_id=task_id, value=output, exception=exception))
    except KeyboardInterrupt:
        pass
    except Exception:
        logger.exception("Worker failed")

    logger.info("Worker exiting")


def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Prepend each output line with process-specific prefix"""

    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write
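
    # Replace file.write with a closure that tracks whether the stream is at
    # the start of a line and writes the colored prefix there.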
    def write_with_prefix(s: str):
        if not s:
            return
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        while (next_idx := s.find('\n', idx)) != -1:
            # Emit up to and including each newline, then re-emit the prefix
            # at the start of the following line.
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                file.start_new_line = True  # type: ignore[attr-defined]
                return
            file_write(prefix)
            idx = next_idx
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]