1
0

test_utils.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import ray
  2. from aphrodite.common.utils import get_open_port
  3. from aphrodite.distributed import (ensure_model_parallel_initialized,
  4. init_distributed_environment)
  def init_test_distributed_environment(
      pipeline_parallel_size: int,
      tensor_parallel_size: int,
      rank: int,
      distributed_init_port: str,
      local_rank: int = -1,
  ) -> None:
      """Initialize the distributed environment and model-parallel groups.

      The world size passed to ``init_distributed_environment`` is the
      product of ``pipeline_parallel_size`` and ``tensor_parallel_size``. The
      rendezvous address is ``tcp://localhost:<distributed_init_port>``.

      Args:
          pipeline_parallel_size: Pipeline-parallel size; it is multiplied
              into the world size and forwarded to
              ``ensure_model_parallel_initialized``.
          tensor_parallel_size: Tensor-parallel size; it is multiplied into
              the world size and forwarded to
              ``ensure_model_parallel_initialized``.
          rank: Rank of this worker, forwarded unchanged.
          distributed_init_port: Port interpolated into the TCP init URL.
              NOTE(review): annotated ``str``, but ``multi_process_tensor_parallel``
              passes the result of ``get_open_port()``; confirm the intended type.
          local_rank: Forwarded unchanged to ``init_distributed_environment``.
              Defaults to -1.
      """
      total_workers = pipeline_parallel_size * tensor_parallel_size
      # Every worker rendezvouses on the same local TCP endpoint.
      init_url = f"tcp://localhost:{distributed_init_port}"
      init_distributed_environment(
          world_size=total_workers,
          rank=rank,
          distributed_init_method=init_url,
          local_rank=local_rank,
      )
      ensure_model_parallel_initialized(tensor_parallel_size,
                                        pipeline_parallel_size)
  def multi_process_tensor_parallel(
      tensor_parallel_size: int,
      test_target,
  ) -> None:
      """Run ``test_target`` once per tensor-parallel rank as Ray tasks.

      The function starts Ray and picks an open port for the distributed
      rendezvous. It then calls
      ``test_target.remote(tensor_parallel_size, rank, port)`` for each
      ``rank`` in ``range(tensor_parallel_size)`` and waits for every task
      with ``ray.get``. Ray is shut down afterwards, whether or not the
      tasks succeeded.

      Args:
          tensor_parallel_size: Number of tasks to launch, one per rank.
          test_target: An object exposing ``.remote(...)``, presumably a Ray
              remote function, that accepts
              ``(tensor_parallel_size, rank, distributed_init_port)``.

      Raises:
          Any exception raised by ``ray.get`` when a task fails. Ray is still
          shut down before the exception propagates.
      """
      # Using ray helps debugging the error when it failed
      # as compared to multiprocessing.
      ray.init()
      try:
          distributed_init_port = get_open_port()
          refs = [
              test_target.remote(tensor_parallel_size, rank,
                                 distributed_init_port)
              for rank in range(tensor_parallel_size)
          ]
          ray.get(refs)
      finally:
          # Always tear Ray down, even if a task failed. Otherwise a later
          # ray.init() in this process would find Ray already initialized.
          ray.shutdown()