1234567891011121314151617181920212223242526272829303132333435363738394041 |
- import ray
- from aphrodite.common.utils import get_open_port
- from aphrodite.distributed import (ensure_model_parallel_initialized,
- init_distributed_environment)
def init_test_distributed_environment(
    pipeline_parallel_size: int,
    tensor_parallel_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    """Set up distributed and model-parallel state for one test worker.

    The process group is initialised over TCP on ``localhost`` using the
    given port. Its size is the product of the pipeline- and
    tensor-parallel degrees. Model-parallel groups are then ensured for
    those same degrees.

    Args:
        pipeline_parallel_size: Number of pipeline-parallel stages.
        tensor_parallel_size: Number of tensor-parallel shards.
        rank: Global rank of the calling worker.
        distributed_init_port: Port used in the ``tcp://localhost:<port>``
            rendezvous URL. It is only interpolated into an f-string, so
            any value that formats as a port works.
        local_rank: Forwarded unchanged to ``init_distributed_environment``.
            NOTE(review): the meaning of the ``-1`` default is defined by
            that helper and is not visible here.
    """
    world_size = pipeline_parallel_size * tensor_parallel_size
    init_method = f"tcp://localhost:{distributed_init_port}"

    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        distributed_init_method=init_method,
        local_rank=local_rank,
    )
    # Model-parallel groups are built only after the global process group
    # exists; the arguments are passed positionally as (tp, pp).
    ensure_model_parallel_initialized(tensor_parallel_size,
                                      pipeline_parallel_size)
def multi_process_tensor_parallel(
    tensor_parallel_size: int,
    test_target,
) -> None:
    """Run ``test_target`` once per tensor-parallel rank as Ray tasks.

    A Ray runtime is started and one rendezvous port is picked for all
    ranks. ``test_target.remote(tensor_parallel_size, rank, port)`` is
    launched for every rank in ``range(tensor_parallel_size)``, and the
    function blocks until every task has finished.

    Args:
        tensor_parallel_size: Number of ranks (and therefore tasks) to
            launch.
        test_target: A Ray remote function (anything exposing
            ``.remote(...)``). It takes ``(tensor_parallel_size, rank,
            distributed_init_port)``.

    Raises:
        Whatever ``ray.get`` raises when a task fails. The Ray runtime is
        shut down before the exception propagates.
    """
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init()
    try:
        # One shared port so all ranks rendezvous at the same address.
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(tensor_parallel_size, rank,
                               distributed_init_port)
            for rank in range(tensor_parallel_size)
        ]
        ray.get(refs)
    finally:
        # Always tear Ray down, even when a worker task failed. Leaving the
        # runtime up would leak into later tests that call ray.init() again.
        ray.shutdown()
|