import multiprocessing
import random
import time
from typing import List

import numpy as np
import torch.distributed as dist

from aphrodite.common.utils import update_environment_variables
from aphrodite.distributed.device_communicators.shm_broadcast import (
    MessageQueue)


def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
    np.random.seed(seed)
    sizes = np.random.randint(1, 10_000, n)
    # on average, each array has ~5k elements;
    # with int64, that is ~40 KB per array
    return [np.random.randint(1, 100, i) for i in sizes]


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # a worker that failed an assertion exits non-zero,
    # which surfaces the failure in the parent test
    for p in processes:
        assert p.exitcode == 0


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly,
    # so we pass them as an argument and apply them inside the worker
    # before joining the process group
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn
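

# `gloo` runs on CPU, so this test needs no GPUs; with the default "env://"
# rendezvous, init_process_group reads RANK, WORLD_SIZE, MASTER_ADDR and
# MASTER_PORT from the environment prepared in distributed_run.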
@worker_fn_wrapper
def worker_fn():
    writer_rank = 2
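    # Create the shared-memory ring buffer used for broadcasting. Judging by
    # the call below, the two integer arguments are the per-chunk buffer size
    # (40 * 1024 bytes) and the number of chunks (2); those meanings are an
    # assumption read off the call site, not confirmed against the API.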
    broadcaster = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank)
    if dist.get_rank() == writer_rank:
        seed = random.randint(0, 1000)
        dist.broadcast_object_list([seed], writer_rank)
    else:
        recv = [None]
        dist.broadcast_object_list(recv, writer_rank)
        seed = recv[0]  # type: ignore
    dist.barrier()
    # print the seed so that, if we hit a race condition,
    # the failing run can be reproduced
    print(f"Rank {dist.get_rank()} got seed {seed}")
    # test broadcasting with about 400MB of data in total
    N = 10_000
    if dist.get_rank() == writer_rank:
        arrs = get_arrays(N, seed)
        for x in arrs:
            broadcaster.broadcast_object(x)
            # random jitter to vary the writer/reader interleaving
            time.sleep(random.random() / 1000)
    else:
        # readers regenerate the same arrays from the shared seed,
        # so each received object can be checked against ground truth
        arrs = get_arrays(N, seed)
        for x in arrs:
            y = broadcaster.broadcast_object(None)
            assert np.array_equal(x, y)
            time.sleep(random.random() / 1000)
    dist.barrier()


def test_shm_broadcast():
    distributed_run(worker_fn, 4)
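

if __name__ == "__main__":
    # convenience entry point so the test can also be launched directly with
    # `python <path to this file>` instead of through pytest; assumes the
    # default (fork) start method so the children can inherit worker_fn
    test_shm_broadcast()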