test_custom_all_reduce.py

import os
import random

import pytest
import ray
import torch
import torch.distributed as dist

from aphrodite.common.test_utils import (init_test_distributed_environment,
                                         multi_process_tensor_parallel)
from aphrodite.distributed import tensor_model_parallel_all_reduce
from aphrodite.distributed.device_communicators import custom_all_reduce

random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
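# Trim each size down to a multiple of 8 elements (assumed alignment
# requirement of the custom all-reduce kernel).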
for i, v in enumerate(test_sizes):
    test_sizes[i] -= v % 8
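

# graph_allreduce: runs the custom all-reduce inside a captured CUDA graph and
# cross-checks every result against a plain torch.distributed (NCCL)
# all-reduce on the same data. CUDA_VISIBLE_DEVICES (set by Ray for each
# worker) is deleted so that each rank can address its own physical GPU as
# cuda:{rank}.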
@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(world_size, rank, distributed_init_port):
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, world_size, rank,
                                      distributed_init_port)

    custom_all_reduce.init_custom_all_reduce()
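    # Sweep several message sizes and dtypes: capture both the custom
    # all-reduce and a reference NCCL all-reduce into a CUDA graph, then
    # replay the graph and compare the results.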
    for sz in test_sizes:
        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            with custom_all_reduce.capture():
                # use integers so result matches NCCL exactly
                inp1 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                inp2 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                torch.cuda.synchronize()
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph):
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    # the input buffer is immediately modified to test
                    # synchronization
                    dist.all_reduce(inp1)
                    out2 = tensor_model_parallel_all_reduce(inp2)
                    dist.all_reduce(inp2)
            graph.replay()
            assert torch.allclose(out1, inp1)
            assert torch.allclose(out2, inp2)
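

# eager_allreduce: drives the custom all-reduce handle directly in eager mode
# (no CUDA graph), summing tensors of ones across ranks and checking that
# every element equals world_size.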
@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(world_size, rank, distributed_init_port):
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, world_size, rank,
                                      distributed_init_port)

    sz = 1024
    custom_all_reduce.init_custom_all_reduce()
    fa = custom_all_reduce.get_handle()
    inp = torch.ones(sz, dtype=torch.float32, device=device)
    out = fa.all_reduce_unreg(inp)
    assert torch.allclose(out, inp * world_size)

    inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
    out = fa.all_reduce_unreg(inp)
    assert torch.allclose(out, inp * world_size)
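

# Pytest entry point: multi_process_tensor_parallel launches one Ray worker
# per rank and runs the selected test target on each of them.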
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    multi_process_tensor_parallel(tensor_parallel_size, test_target)
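

# Allow running the graph-capture variant directly, without pytest, on a
# machine with at least two GPUs.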
if __name__ == "__main__":
    multi_process_tensor_parallel(2, graph_allreduce)