ray_gpu_executor.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. import asyncio
  2. import os
  3. from collections import defaultdict
  4. from itertools import islice, repeat
  5. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
  6. from loguru import logger
  7. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  8. from aphrodite.common.utils import (_run_task_with_lock,
  9. get_aphrodite_instance_id,
  10. get_distributed_init_method, get_ip,
  11. get_open_port, make_async)
  12. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  13. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  14. from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
  15. if ray is not None:
  16. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  17. if TYPE_CHECKING:
  18. from ray.util.placement_group import PlacementGroup
  19. # If the env var is set, it uses the Ray's compiled DAG API
  20. # which optimizes the control plane overhead.
  21. # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
  22. APHRODITE_USE_RAY_COMPILED_DAG = bool(
  23. os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
  24. APHRODITE_TRACE_FUNCTION = int(os.getenv("APHRODITE_TRACE_FUNCTION", 0))
  25. APHRODITE_USE_RAY_SPMD_WORKER = bool(
  26. os.getenv("APHRODITE_USE_RAY_SPMD_WORKER", 0))
  27. APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = bool(
  28. int(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", 1)))
  29. class RayGPUExecutor(DistributedGPUExecutor):
  30. uses_ray: bool = True
  31. def _init_executor(self) -> None:
  32. self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
  33. # If the env var is set, it uses the Ray's compiled DAG API
  34. # which optimizes the control plane overhead.
  35. # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
  36. # Currently, this requires USE_RAY_SPMD_WORKER=True.
  37. self.use_ray_compiled_dag = APHRODITE_USE_RAY_COMPILED_DAG
  38. # If the env var is set, then we do not distinguish between the
  39. # "driver worker" vs other workers. Also, the rank 0 worker will
  40. # be executed in a remote Ray worker. Currently this requires
  41. # USE_RAY_COMPILED_DAG=True.
  42. self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
  43. if self.use_ray_compiled_dag:
  44. assert self.use_ray_spmd_worker, (
  45. "APHRODITE_USE_RAY_COMPILED_DAG=1 requires "
  46. "APHRODITE_USE_RAY_SPMD_WORKER=1")
  47. if self.use_ray_spmd_worker:
  48. # TODO: Support SPMD worker for non-DAG Ray executor.
  49. assert self.use_ray_compiled_dag, (
  50. "APHRODITE_USE_RAY_SPMD_WORKER=1 requires "
  51. "APHRODITE_USE_RAY_COMPILED_DAG=1")
  52. assert self.uses_ray
  53. placement_group = self.parallel_config.placement_group
  54. # Disable Ray usage stats collection.
  55. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  56. if ray_usage != "1":
  57. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  58. # Create the parallel GPU workers.
  59. self._init_workers_ray(placement_group)
  60. def _configure_ray_workers_use_nsight(self,
  61. ray_remote_kwargs) -> Dict[str, Any]:
  62. # If nsight profiling is enabled, we need to set the profiling
  63. # configuration for the ray workers as runtime env.
  64. runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
  65. runtime_env.update({
  66. "nsight": {
  67. "t": "cuda,cudnn,cublas",
  68. "o": "'worker_process_%p'",
  69. "cuda-graph-trace": "node",
  70. }
  71. })
  72. return ray_remote_kwargs
  73. def _get_worker_wrapper_args(self) -> Dict[str, Any]:
  74. if self.speculative_config is not None:
  75. worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
  76. worker_class_name = "create_spec_worker"
  77. else:
  78. worker_module_name = "aphrodite.task_handler.worker"
  79. worker_class_name = "Worker"
  80. return dict(
  81. worker_module_name=worker_module_name,
  82. worker_class_name=worker_class_name,
  83. trust_remote_code=self.model_config.trust_remote_code,
  84. )
  85. def _init_workers_ray(self, placement_group: "PlacementGroup",
  86. **ray_remote_kwargs):
  87. if (self.parallel_config.tensor_parallel_size == 1
  88. and self.parallel_config.pipeline_parallel_size == 1):
  89. # For single GPU case, we use a ray worker with constrained memory.
  90. num_gpus = self.cache_config.gpu_memory_utilization
  91. else:
  92. # Otherwise, the ray workers are allocated with a full GPU.
  93. num_gpus = 1
  94. # The driver dummy worker does not actually use any resources.
  95. # It holds the resource for the driver worker.
  96. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
  97. # The remaining workers are the actual ray actors.
  98. self.workers: List[RayWorkerWrapper] = []
  99. # Used in ray compiled DAG: indexed first by PP rank,
  100. # and then TP rank. In other words, the inner list is
  101. # the TP group of workers for a PP rank.
  102. self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
  103. if self.parallel_config.ray_workers_use_nsight:
  104. ray_remote_kwargs = self._configure_ray_workers_use_nsight(
  105. ray_remote_kwargs)
  106. logger.info(f"use_ray_spmd_worker: {self.use_ray_spmd_worker}")
  107. # Create the workers.
  108. driver_ip = get_ip()
  109. logger.info(f"driver_ip: {driver_ip}")
  110. worker_wrapper_kwargs = self._get_worker_wrapper_args()
  111. for bundle_id, bundle in enumerate(placement_group.bundle_specs):
  112. if not bundle.get("GPU", 0):
  113. continue
  114. scheduling_strategy = PlacementGroupSchedulingStrategy(
  115. placement_group=placement_group,
  116. placement_group_capture_child_tasks=True,
  117. placement_group_bundle_index=bundle_id,
  118. )
  119. worker = ray.remote(
  120. num_cpus=0,
  121. num_gpus=num_gpus,
  122. scheduling_strategy=scheduling_strategy,
  123. **ray_remote_kwargs,
  124. )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
  125. if self.use_ray_spmd_worker:
  126. self.workers.append(worker)
  127. else:
  128. worker_ip = ray.get(worker.get_node_ip.remote())
  129. if worker_ip == driver_ip and self.driver_dummy_worker is None:
  130. # If the worker is on the same node as the driver, we use it
  131. # as the resource holder for the driver process.
  132. self.driver_dummy_worker = worker
  133. self.driver_worker = RayWorkerWrapper(
  134. **worker_wrapper_kwargs)
  135. else:
  136. # Else, added to the list of workers.
  137. self.workers.append(worker)
  138. logger.debug(f"workers: {self.workers}")
  139. logger.debug(f"driver_dummy_worker: {self.driver_dummy_worker}")
  140. if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
  141. raise ValueError(
  142. "Ray does not allocate any GPUs on the driver node. Consider "
  143. "adjusting the Ray placement group or running the driver on a "
  144. "GPU node.")
  145. worker_ips = [
  146. ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
  147. for worker in self.workers
  148. ]
  149. ip_counts: Dict[str, int] = {}
  150. for ip in worker_ips:
  151. ip_counts[ip] = ip_counts.get(ip, 0) + 1
  152. def sort_by_driver_then_worker_ip(worker):
  153. """
  154. Sort the workers based on 3 properties:
  155. 1. If the worker is on the same node as the driver (vllm engine),
  156. it should be placed first.
  157. 2. Then, if the worker is on a node with fewer workers, it should
  158. be placed first.
  159. 3. Finally, if the work is on a node with smaller IP address, it
  160. should be placed first.
  161. """
  162. ip = ray.get(worker.get_node_ip.remote())
  163. return (ip != driver_ip, ip_counts[ip], ip)
  164. # After sorting, the workers on the same node will be
  165. # close to each other, and the workers on the driver
  166. # node will be placed first.
  167. self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
  168. # Get the set of GPU IDs used on each node.
  169. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
  170. use_dummy_driver=True)
  171. node_workers = defaultdict(list) # node id -> list of worker ranks
  172. node_gpus = defaultdict(list) # node id -> list of gpu ids
  173. for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
  174. node_workers[node_id].append(i)
  175. # `gpu_ids` can be a list of strings or integers.
  176. # convert them to integers for consistency.
  177. # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
  178. # string sorting is not sufficient.
  179. gpu_ids = [int(x) for x in gpu_ids]
  180. node_gpus[node_id].extend(gpu_ids)
  181. for node_id, gpu_ids in node_gpus.items():
  182. node_gpus[node_id] = sorted(gpu_ids)
  183. APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()
  184. # Set environment variables for the driver and workers.
  185. all_args_to_update_environment_variables = [({
  186. "CUDA_VISIBLE_DEVICES":
  187. ",".join(map(str, node_gpus[node_id])),
  188. "APHRODITE_INSTANCE_ID":
  189. APHRODITE_INSTANCE_ID,
  190. "APHRODITE_TRACE_FUNCTION":
  191. str(APHRODITE_TRACE_FUNCTION),
  192. }, ) for (node_id, _) in worker_node_and_gpu_ids]
  193. self._run_workers("update_environment_variables",
  194. all_args=all_args_to_update_environment_variables)
  195. if len(node_gpus) == 1:
  196. # in single node case, we don't need to get the IP address.
  197. # the loopback address is sufficient
  198. # NOTE: a node may have several IP addresses, one for each
  199. # network interface. `get_ip()` might return any of them,
  200. # while they might not work for communication inside the node
  201. # if the network setup is complicated. Using the loopback address
  202. # solves this issue, as it always works for communication inside
  203. # the node.
  204. driver_ip = "127.0.0.1"
  205. distributed_init_method = get_distributed_init_method(
  206. driver_ip, get_open_port())
  207. # Initialize the actual workers inside worker wrapper.
  208. init_worker_all_kwargs = [
  209. self._get_worker_kwargs(
  210. local_rank=node_workers[node_id].index(rank),
  211. rank=rank,
  212. distributed_init_method=distributed_init_method,
  213. ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
  214. ]
  215. self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
  216. self._run_workers("init_device")
  217. self._run_workers("load_model",
  218. max_concurrent_workers=self.parallel_config.
  219. max_parallel_loading_workers)
  220. if self.use_ray_spmd_worker:
  221. for pp_rank in range(self.parallel_config.pipeline_parallel_size):
  222. self.pp_tp_workers.append([])
  223. for tp_rank in range(
  224. self.parallel_config.tensor_parallel_size):
  225. # PP=2, TP=4
  226. # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
  227. rank = (pp_rank * self.parallel_config.tensor_parallel_size
  228. ) + tp_rank
  229. assert len(self.pp_tp_workers[pp_rank]) == tp_rank
  230. assert pp_rank < len(self.pp_tp_workers)
  231. self.pp_tp_workers[pp_rank].append(self.workers[rank])
  232. # This is the list of workers that are rank 0 of each TP group EXCEPT
  233. # global rank 0. These are the workers that will broadcast to the
  234. # rest of the workers.
  235. self.tp_driver_workers: List[RayWorkerWrapper] = []
  236. # This is the list of workers that are not drivers and not the first
  237. # worker in a TP group. These are the workers that will be
  238. # broadcasted to.
  239. self.non_driver_workers: List[RayWorkerWrapper] = []
  240. # Enforce rank order for correct rank to return final output.
  241. for index, worker in enumerate(self.workers):
  242. # The driver worker is rank 0 and not in self.workers.
  243. rank = index + 1
  244. if rank % self.parallel_config.tensor_parallel_size == 0:
  245. self.tp_driver_workers.append(worker)
  246. else:
  247. self.non_driver_workers.append(worker)
  248. def _driver_execute_model(
  249. self, execute_model_req: Optional[ExecuteModelRequest]
  250. ) -> Optional[List[SamplerOutput]]:
  251. """Run execute_model in the driver worker.
  252. Passing None will cause the driver to stop the model execution
  253. loop running in each of the remote workers.
  254. """
  255. assert not self.use_ray_spmd_worker, (
  256. "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
  257. return self.driver_worker.execute_method("execute_model",
  258. execute_model_req)
  259. def execute_model(
  260. self,
  261. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  262. if not self.use_ray_spmd_worker:
  263. return super().execute_model(execute_model_req)
  264. if self.forward_dag is None:
  265. self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
  266. outputs = ray.get(self.forward_dag.execute(execute_model_req))
  267. return outputs[0]
  268. def _run_workers(
  269. self,
  270. method: str,
  271. *args,
  272. async_run_tensor_parallel_workers_only: bool = False,
  273. all_args: Optional[List[Tuple[Any, ...]]] = None,
  274. all_kwargs: Optional[List[Dict[str, Any]]] = None,
  275. use_dummy_driver: bool = False,
  276. max_concurrent_workers: Optional[int] = None,
  277. **kwargs,
  278. ) -> Any:
  279. """Runs the given method on all workers. Can be used in the following
  280. ways:
  281. Args:
  282. - async_run_tensor_parallel_workers_only: If True the method will be
  283. run only in the remote TP workers, not the driver worker.
  284. It will also be run asynchronously and return a list of futures
  285. rather than blocking on the results.
  286. - args/kwargs: All workers share the same args/kwargs
  287. - all_args/all_kwargs: args/kwargs for each worker are specified
  288. individually
  289. """
  290. if self.use_ray_spmd_worker:
  291. assert not async_run_tensor_parallel_workers_only, (
  292. "async_run_tensor_parallel_workers_only is not supported for "
  293. "spmd mode.")
  294. if max_concurrent_workers:
  295. raise NotImplementedError(
  296. "max_concurrent_workers is not supported yet.")
  297. count = len(self.workers) if not \
  298. async_run_tensor_parallel_workers_only \
  299. else len(self.non_driver_workers)
  300. # If using SPMD worker, all workers are the same, so we should execute
  301. # the args on all workers. Otherwise, we skip the first worker's args
  302. # because those args will go to the driver worker.
  303. first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
  304. all_worker_args = repeat(args, count) if all_args is None \
  305. else islice(all_args, first_worker_args_index, None)
  306. all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
  307. else islice(all_kwargs, first_worker_args_index, None)
  308. # Start the ray workers first.
  309. ray_workers = self.workers
  310. if async_run_tensor_parallel_workers_only:
  311. ray_workers = self.non_driver_workers
  312. ray_worker_outputs = [
  313. worker.execute_method.remote(method, *worker_args, **worker_kwargs)
  314. for (worker, worker_args, worker_kwargs
  315. ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
  316. ]
  317. if async_run_tensor_parallel_workers_only:
  318. # Just return futures
  319. return ray_worker_outputs
  320. driver_worker_output = []
  321. # In SPMD mode, the driver worker is the same as any other worker,
  322. # so we only explicitly execute on the driver worker if using a
  323. # non-SPMD worker class.
  324. if not self.use_ray_spmd_worker:
  325. driver_args = args if all_args is None else all_args[0]
  326. driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
  327. # Start the driver worker after all the ray workers.
  328. if not use_dummy_driver:
  329. driver_worker_output = [
  330. self.driver_worker.execute_method(method, *driver_args,
  331. **driver_kwargs)
  332. ]
  333. else:
  334. assert self.driver_dummy_worker is not None
  335. driver_worker_output = [
  336. ray.get(
  337. self.driver_dummy_worker.execute_method.remote(
  338. method, *driver_args, **driver_kwargs))
  339. ]
  340. # Get the results of the ray workers.
  341. if self.workers:
  342. ray_worker_outputs = ray.get(ray_worker_outputs)
  343. return driver_worker_output + ray_worker_outputs
  344. def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
  345. """Wait for futures returned from _run_workers() with
  346. async_run_remote_workers_only to complete."""
  347. ray.get(parallel_worker_tasks)
  348. def _compiled_ray_dag(self, enable_asyncio: bool):
  349. import pkg_resources
  350. from packaging import version
  351. required_version = version.parse("2.32")
  352. current_version = version.parse(
  353. pkg_resources.get_distribution("ray").version)
  354. if current_version < required_version:
  355. raise ValueError(f"Ray version {required_version} or greater is "
  356. f"required, but found {current_version}")
  357. assert self.parallel_config.use_ray
  358. from ray.dag import InputNode, MultiOutputNode
  359. from ray.experimental.channel.torch_tensor_type import TorchTensorType
  360. logger.info(f"APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = "
  361. f"{APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL}")
  362. with InputNode() as input_data:
  363. # Example DAG: PP=2, TP=4
  364. # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501
  365. # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501
  366. # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501
  367. # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501
  368. # All workers in the first TP group will take in the
  369. # ExecuteModelRequest as input.
  370. outputs = [input_data for _ in self.pp_tp_workers[0]]
  371. for pp_rank, tp_group in enumerate(self.pp_tp_workers):
  372. # Each PP worker takes in the output of the previous PP worker,
  373. # and the TP group executes in SPMD fashion.
  374. outputs = [
  375. worker.execute_model_spmd.
  376. bind( # type: ignore[attr-defined]
  377. outputs[i]) for i, worker in enumerate(tp_group)
  378. ]
  379. last_pp_rank = len(self.pp_tp_workers) - 1
  380. if pp_rank < last_pp_rank:
  381. # Specify how intermediate tensors should be passed
  382. # between pp stages, no need to specify for the last
  383. # pp stage.
  384. transport = "nccl" \
  385. if APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
  386. else "auto"
  387. outputs = [
  388. output.with_type_hint(
  389. TorchTensorType(transport=transport))
  390. for output in outputs
  391. ]
  392. forward_dag = MultiOutputNode(outputs)
  393. return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
  394. def __del__(self):
  395. if self.forward_dag is not None:
  396. self.forward_dag.teardown()
  397. import ray
  398. for worker in self.workers:
  399. ray.kill(worker)
  400. class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
  401. def __init__(self, *args, **kwargs):
  402. super().__init__(*args, **kwargs)
  403. self.pp_locks: Optional[List[asyncio.Lock]] = None
  404. self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
  405. if not self.use_ray_compiled_dag:
  406. self.driver_exec_method = make_async(
  407. self.driver_worker.execute_method)
  408. async def execute_model_async(
  409. self,
  410. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  411. if not self.use_ray_spmd_worker:
  412. return await super().execute_model_async(execute_model_req)
  413. if self.forward_dag is None:
  414. self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
  415. dag_future = await self.forward_dag.execute_async(execute_model_req)
  416. outputs = await dag_future
  417. return outputs[0]
  418. async def _driver_execute_model_async(
  419. self,
  420. execute_model_req: Optional[ExecuteModelRequest] = None
  421. ) -> List[SamplerOutput]:
  422. assert not self.use_ray_spmd_worker, (
  423. "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
  424. if not self.tp_driver_workers:
  425. return await self.driver_exec_method("execute_model",
  426. execute_model_req)
  427. if self.pp_locks is None:
  428. # This locks each pipeline parallel stage so multiple virtual
  429. # engines can't execute on the same stage at the same time
  430. # We create the locks here to avoid creating them in the constructor
  431. # which uses a different asyncio loop.
  432. self.pp_locks = [
  433. asyncio.Lock()
  434. for _ in range(self.parallel_config.pipeline_parallel_size)
  435. ]
  436. tasks = [
  437. asyncio.create_task(
  438. _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
  439. "execute_model", execute_model_req))
  440. ]
  441. for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
  442. start=1):
  443. tasks.append(
  444. asyncio.create_task(
  445. _run_task_with_lock(driver_worker.execute_method.remote,
  446. self.pp_locks[pp_rank],
  447. "execute_model", execute_model_req)))
  448. results = await asyncio.gather(*tasks)
  449. # Only the last PP stage has the final results.
  450. return results[-1]
  451. async def _start_worker_execution_loop(self):
  452. assert not self.use_ray_spmd_worker, (
  453. "worker loop is disabled for APHRODITE_USE_RAY_SPMD_WORKER=1")
  454. coros = [
  455. worker.execute_method.remote("start_worker_execution_loop")
  456. for worker in self.non_driver_workers
  457. ]
  458. return await asyncio.gather(*coros)
  459. def __del__(self):
  460. if self.forward_dag is not None:
  461. self.forward_dag.teardown()
  462. import ray
  463. for worker in self.workers:
  464. ray.kill(worker)