test_shm_broadcast.py

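"""Test for the shared-memory MessageQueue broadcast.

One writer rank pushes a stream of numpy arrays through the queue while the
remaining ranks read them back and compare them against arrays regenerated
locally from the same random seed.
"""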
import multiprocessing
import random
import time
from typing import List

import numpy as np
import torch.distributed as dist

from aphrodite.common.utils import update_environment_variables
from aphrodite.distributed.device_communicators.shm_broadcast import (
    MessageQueue)


def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
    np.random.seed(seed)
    sizes = np.random.randint(1, 10_000, n)
    # on average, each array has about 5k elements;
    # with int64, that is roughly 40 KB per array
    return [np.random.randint(1, 100, i) for i in sizes]

def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        # set up the environment a Gloo `torch.distributed` process group needs
        env = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0

def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn

@worker_fn_wrapper
def worker_fn():
    writer_rank = 2
    broadcaster = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank)
    # agree on a random seed so that every rank can regenerate the same arrays
    if dist.get_rank() == writer_rank:
        seed = random.randint(0, 1000)
        dist.broadcast_object_list([seed], writer_rank)
    else:
        recv = [None]
        dist.broadcast_object_list(recv, writer_rank)
        seed = recv[0]  # type: ignore
    dist.barrier()
    # print the seed so that, if we hit a race condition,
    # we can reproduce the error
    print(f"Rank {dist.get_rank()} got seed {seed}")
    # test broadcasting about 400 MB of data in total
    # (10k arrays of roughly 40 KB each)
    N = 10_000
    if dist.get_rank() == writer_rank:
        arrs = get_arrays(N, seed)
        for x in arrs:
            broadcaster.broadcast_object(x)
            time.sleep(random.random() / 1000)
    else:
        arrs = get_arrays(N, seed)
        for x in arrs:
            y = broadcaster.broadcast_object(None)
            assert np.array_equal(x, y)
            time.sleep(random.random() / 1000)
    dist.barrier()

def test_shm_broadcast():
    distributed_run(worker_fn, 4)
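

# Hypothetical convenience entry point (not part of the original test): lets
# the file be run directly instead of through pytest. It assumes the default
# 'fork' multiprocessing start method, since the worker is a local closure.
if __name__ == "__main__":
    test_shm_broadcast()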