  1. """Test the communication operators.
  2. Run `pytest tests/distributed/test_comm_ops.py`.
  3. """
import os

import pytest
import ray
import torch

from aphrodite.common.test_utils import (init_test_distributed_environment,
                                         multi_process_tensor_parallel)
from aphrodite.distributed import (broadcast_tensor_dict,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce)


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    # It is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs and set its device to the
    # correct one.
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
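    # Each rank contributes a tensor scaled by (rank + 1); after the
    # all-reduce every rank should hold the elementwise sum across ranks.
    # For example, with tensor_parallel_size == 2, rank 0 holds
    # [0, 1, ..., 7] and rank 1 holds [0, 2, ..., 14], so both end up with
    # [0, 3, ..., 21].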
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tensor_parallel_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank]
    t = tensor_model_parallel_all_reduce(t)
    assert torch.allclose(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    # It is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs and set its device to the
    # correct one.
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
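    # Each rank builds a (2, 3, 4) tensor scaled by (rank + 1) and gathers it
    # along every dimension in turn; the result must match torch.cat of all
    # ranks' tensors along that dimension, e.g. shape
    # (2 * tensor_parallel_size, 3, 4) when gathering along dim 0.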
    num_dimensions = 3
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s
    for all_gather_dimension in range(num_dimensions):
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tensor_parallel_size)
        ]
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = all_tensors[rank]
        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
        assert torch.allclose(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
    # It is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs and set its device to the
    # correct one.
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
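    # Rank 0 is the source and passes the dict to broadcast; the other ranks
    # call broadcast_tensor_dict(src=0) with no payload and verify that
    # tensors, strings, lists, and nested dicts all arrive intact.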
    test_dict = {
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {
            "a": 1,
            "b": 2
        },
    }
    if rank == 0:
        broadcast_tensor_dict(test_dict, src=0)
    else:
        recv_dict = broadcast_tensor_dict(src=0)
        assert len(recv_dict) == len(test_dict)
        assert torch.allclose(recv_dict["a"], test_dict["a"])
        assert torch.allclose(recv_dict["b"], test_dict["b"])
        assert recv_dict["c"] == test_dict["c"]
        assert recv_dict["d"] == test_dict["d"]
        assert recv_dict["e"] == test_dict["e"]
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target", [
    all_reduce_test_worker, all_gather_test_worker,
    broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    multi_process_tensor_parallel(tensor_parallel_size, test_target)
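

# A minimal sketch for running this file directly rather than through the
# pytest CLI; this __main__ block is an assumed convenience and is not part
# of the original test suite.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))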