test_utils.py

import ray

from aphrodite.common.config import ParallelConfig
from aphrodite.common.utils import get_open_port
from aphrodite.task_handler.worker import init_distributed_environment


def init_test_distributed_environment(
    pipeline_parallel_size: int,
    tensor_parallel_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
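    """Initialize the distributed environment for a single test worker."""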
    parallel_config = ParallelConfig(pipeline_parallel_size,
                                     tensor_parallel_size,
                                     worker_use_ray=True)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(parallel_config, rank,
                                 distributed_init_method)


def multi_process_tensor_parallel(
    tensor_parallel_size: int,
    test_target,
) -> None:
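    """Launch one Ray worker per tensor-parallel rank and wait for all."""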
    # Use Ray rather than multiprocessing: when a worker fails, Ray
    # makes the error much easier to debug.
    ray.init()
    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tensor_parallel_size):
        refs.append(
            test_target.remote(tensor_parallel_size, rank,
                               distributed_init_port))
    ray.get(refs)
    ray.shutdown()
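
For context, test_target must be a Ray remote function that accepts
(tensor_parallel_size, rank, distributed_init_port), since that is how it is
invoked above. A minimal usage sketch, assuming a hypothetical worker named
all_reduce_test_worker (the worker body is illustrative, not part of this
file):

@ray.remote(num_gpus=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str) -> None:
    # Hypothetical test body: join the process group (pipeline-parallel
    # size 1), then run the actual per-rank assertions here.
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)

multi_process_tensor_parallel(2, all_reduce_test_worker)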