1
0

test_utils.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import ray
  2. from aphrodite.common.utils import get_open_port
  3. from aphrodite.distributed import (
  4. ensure_model_parallel_initialized,
  5. init_distributed_environment,
  6. )
  7. def init_test_distributed_environment(
  8. pipeline_parallel_size: int,
  9. tensor_parallel_size: int,
  10. rank: int,
  11. distributed_init_port: str,
  12. local_rank: int = -1,
  13. ) -> None:
  14. distributed_init_method = f"tcp://localhost:{distributed_init_port}"
  15. init_distributed_environment(
  16. world_size=pipeline_parallel_size * tensor_parallel_size,
  17. rank=rank,
  18. distributed_init_method=distributed_init_method,
  19. local_rank=local_rank)
  20. ensure_model_parallel_initialized(tensor_parallel_size,
  21. pipeline_parallel_size)
  22. def multi_process_tensor_parallel(
  23. tensor_parallel_size: int,
  24. test_target,
  25. ) -> None:
  26. # Using ray helps debugging the error when it failed
  27. # as compared to multiprocessing.
  28. ray.init()
  29. distributed_init_port = get_open_port()
  30. refs = []
  31. for rank in range(tensor_parallel_size):
  32. refs.append(
  33. test_target.remote(tensor_parallel_size, rank,
  34. distributed_init_port))
  35. ray.get(refs)
  36. ray.shutdown()