# test_custom_all_reduce.py

import os
import random

import pytest
import ray
import torch
import torch.distributed as dist

from aphrodite.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce)
from aphrodite.distributed.parallel_state import (
    get_tensor_model_parallel_group, get_tp_group, graph_capture)

from ..utils import (ensure_model_parallel_initialized,
                     init_test_distributed_environment,
                     multi_process_parallel)

random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
for i, v in enumerate(test_sizes):
    test_sizes[i] -= v % 8
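# NOTE: the sizes above are rounded down to a multiple of 8 elements,
# presumably so that the buffers satisfy the custom all-reduce kernel's
# alignment requirement for the 2- and 4-byte dtypes tested below.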


@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)
    ensure_model_parallel_initialized(tp_size, pp_size)
    group = get_tensor_model_parallel_group().device_group

    # A small all_reduce for warmup.
    # This is needed because device communicators might be created lazily
    # (e.g. NCCL). It ensures that the communicator is initialized before
    # any communication happens, so that this group can be used for graph
    # capture immediately.
    data = torch.zeros(1)
    data = data.to(device=device)
    torch.distributed.all_reduce(data, group=group)
    torch.cuda.synchronize()
    del data

    # We use the first group to communicate once, the second group to
    # communicate twice, and so on. This demonstrates that each group can
    # communicate independently.
    num_communication = rank // tp_size + 1
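    # E.g. with tp_size=2 and pp_size=2 (contiguous TP groups): ranks 0-1 form
    # the first TP group and communicate once (0 // 2 + 1 == 1), while ranks
    # 2-3 form the second TP group and communicate twice (2 // 2 + 1 == 2).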

    for sz in test_sizes:
        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            with graph_capture() as graph_capture_context:
                # Use integers so the result matches NCCL exactly: small
                # integer sums are exact in all tested dtypes, so the
                # reduction order cannot change the result.
                inp1 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                inp2 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                torch.cuda.synchronize()
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph,
                                      stream=graph_capture_context.stream):
                    for i in range(num_communication):
                        out1 = tensor_model_parallel_all_reduce(inp1)
                        # The input buffer is immediately modified to test
                        # synchronization.
                        dist.all_reduce(inp1, group=group)
                        out2 = tensor_model_parallel_all_reduce(inp2)
                        dist.all_reduce(inp2, group=group)
            graph.replay()
            torch.testing.assert_close(out1, inp1)
            torch.testing.assert_close(out2, inp2)
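

# -----------------------------------------------------------------------------
# Minimal illustrative sketch (not collected by pytest, since the name does not
# start with test_) of the CUDA-graph capture/replay pattern that
# graph_allreduce relies on, shown here without any distributed communication.
# The helper name, shapes, and values are ours, and it assumes a CUDA device is
# available.
# -----------------------------------------------------------------------------
def _cuda_graph_capture_replay_sketch():
    static_in = torch.ones(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")
    # Warm up eagerly once before capture (mirrors the warmup all_reduce in
    # graph_allreduce).
    torch.mul(static_in, 2, out=static_out)
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    # Kernels launched inside this context are only recorded into the graph;
    # torch.cuda.graph captures them on an internal side stream.
    with torch.cuda.graph(graph):
        torch.mul(static_in, 2, out=static_out)
    # Mutating the captured input buffer changes what the next replay computes,
    # because the graph always reads from the same static buffers.
    static_in.fill_(3.0)
    # The recorded kernels actually execute here.
    graph.replay()
    torch.testing.assert_close(static_out,
                               torch.full((8, ), 6.0, device="cuda"))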


@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    # We use the first group to communicate once, the second group to
    # communicate twice, and so on. This demonstrates that each group can
    # communicate independently.
    num_communication = rank // tp_size + 1
    sz = 1024
    fa = get_tp_group().ca_comm
    inp = torch.ones(sz, dtype=torch.float32, device=device)
    out = inp
    for _ in range(num_communication):
        out = fa.all_reduce_unreg(out)
    # Each all-reduce over tp_size ranks of all-ones multiplies the values by
    # tp_size, so after num_communication rounds the expected value is
    # tp_size ** num_communication.
    torch.testing.assert_close(out, inp * (tp_size**num_communication))

    inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
    out = inp
    for _ in range(num_communication):
        out = fa.all_reduce_unreg(out)
    torch.testing.assert_close(out, inp * (tp_size**num_communication))


@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")
    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
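

# Typical invocation (an assumption, not prescribed by this file: the test must
# run from its package so the ``..utils`` relative import resolves, and at
# least 2 GPUs must be visible; the pipeline_parallel_size=2 cases need 4 GPUs
# or they are skipped):
#   pytest -s -x test_custom_all_reduce.py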