# test_shm_broadcast.py
import functools
import multiprocessing
import random
import time

import torch.distributed as dist

from aphrodite.common.utils import update_environment_variables
from aphrodite.distributed.device_communicators.shm_broadcast import (
    ShmRingBuffer, ShmRingBufferIO)
  8. def distributed_run(fn, world_size):
  9. number_of_processes = world_size
  10. processes = []
  11. for i in range(number_of_processes):
  12. env = {}
  13. env['RANK'] = str(i)
  14. env['LOCAL_RANK'] = str(i)
  15. env['WORLD_SIZE'] = str(number_of_processes)
  16. env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
  17. env['MASTER_ADDR'] = 'localhost'
  18. env['MASTER_PORT'] = '12345'
  19. p = multiprocessing.Process(target=fn, args=(env, ))
  20. processes.append(p)
  21. p.start()
  22. for p in processes:
  23. p.join()
  24. for p in processes:
  25. assert p.exitcode == 0
  26. def worker_fn_wrapper(fn):
  27. # `multiprocessing.Process` cannot accept environment variables directly
  28. # so we need to pass the environment variables as arguments
  29. # and update the environment variables in the function
  30. def wrapped_fn(env):
  31. update_environment_variables(env)
  32. dist.init_process_group(backend="gloo")
  33. fn()
  34. return wrapped_fn
  35. @worker_fn_wrapper
  36. def worker_fn():
  37. writer_rank = 2
  38. broadcaster = ShmRingBufferIO.create_from_process_group(
  39. dist.group.WORLD, 1024, 2, writer_rank)
  40. if dist.get_rank() == writer_rank:
  41. time.sleep(random.random())
  42. broadcaster.broadcast_object(0)
  43. time.sleep(random.random())
  44. broadcaster.broadcast_object({})
  45. time.sleep(random.random())
  46. broadcaster.broadcast_object([])
  47. else:
  48. time.sleep(random.random())
  49. a = broadcaster.broadcast_object(None)
  50. time.sleep(random.random())
  51. b = broadcaster.broadcast_object(None)
  52. time.sleep(random.random())
  53. c = broadcaster.broadcast_object(None)
  54. assert a == 0
  55. assert b == {}
  56. assert c == []
  57. dist.barrier()
  58. def test_shm_broadcast():
  59. distributed_run(worker_fn, 4)
  60. def test_singe_process():
  61. buffer = ShmRingBuffer(1, 1024, 4)
  62. reader = ShmRingBufferIO(buffer, reader_rank=0)
  63. writer = ShmRingBufferIO(buffer, reader_rank=-1)
  64. writer.enqueue([0])
  65. writer.enqueue([1])
  66. assert reader.dequeue() == [0]
  67. assert reader.dequeue() == [1]