custom_all_reduce.cu

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include <cstring>  // std::memcpy

#include "custom_all_reduce.cuh"

// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));
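
// Construct a CustomAllreduce instance from IPC handles that were exchanged
// out-of-band. `meta` backs the cross-rank signal/barrier state (its required
// size is reported by meta_size() below), `rank_data` is scratch space used to
// track registered buffers, and `handles`/`offsets` locate each peer's
// corresponding allocation. The returned fptr_t is an opaque handle that the
// caller passes back to the other functions in this file and, eventually, to
// dispose().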
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");
  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new aphrodite::CustomAllreduce(
      reinterpret_cast<aphrodite::Signal *>(meta.data_ptr()),
      rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank,
      full_nvlink);
}
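
// Expected call order from the host framework (a sketch, not enforced here):
// init_custom_ar -> register_buffer / register_graph_buffers ->
// all_reduce_reg / all_reduce_unreg -> dispose.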
/**
 * Make sure tensor t's data lies completely within the byte range
 * [t.data_ptr(), (char *)t.data_ptr() + t.numel() * t.element_size()). This is
 * slightly weaker than t.is_contiguous() because it allows a transpose of a
 * contiguous slice (i.e. slicing the first dimension). Currently, we require
 * this because stride information is not passed into the kernels and we treat
 * input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
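// Worked check for examples 2, 3, 5 and 6 above (float32, so element_size()
// is 4 and the storage holds 27 * 4 == 108 bytes):
//   2. A[1:]                        : offset 9, 108 - 9*4 == 72 == 18*4 -> OK
//   3. A.permute(2, 0, 1)           : offset 0, 108 - 0 == 108 == 27*4  -> OK
//   5. A[None].expand(2, -1, -1, -1): numel 54, 54*4 == 216 > 108       -> Not OK
//   6. A[:, 1:, 1:]                 : offset 4, 108 - 16 == 92 != 12*4  -> Not OK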
bool _is_weak_contiguous(torch::Tensor &t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16
  if (inp_size % 16 != 0) return false;
  if (!_is_weak_contiguous(inp)) return false;
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
  // performance improvement over NCCL.
  return false;
}
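
// Example (illustrative shapes): a contiguous half-precision tensor of shape
// [8, 4096] occupies 8 * 4096 * 2 == 65536 bytes, which is a multiple of 16,
// so should_custom_ar returns true as long as 65536 <= max_size and we have
// either exactly 2 ranks or full NVLink connectivity.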
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                           reinterpret_cast<float *>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()),
                          out.numel());
      break;
    }
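    // bf16 kernels are only instantiated in device compilation passes that
    // target SM80 or newer; the !defined(__CUDA_ARCH__) clause keeps the case
    // enabled for the host pass, where __CUDA_ARCH__ is not defined.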
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}
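
// All-reduce an input whose buffer has already been registered with
// register_buffer(), so peer GPUs can read it directly over IPC.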
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}
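
// All-reduce an arbitrary (unregistered) input: the data is first copied into
// the pre-registered staging buffer `reg_buffer` on the current stream and the
// reduction reads from that buffer.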
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
  delete fa;
}
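
// Size in bytes of the Signal structure that heads the `meta` allocation
// handed to init_custom_ar.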
int meta_size() { return sizeof(aphrodite::Signal); }
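
// Register tensor `t` with the allreduce object; `handles`/`offsets` carry the
// IPC handles and offsets of the peers' copies of the same buffer so that
// every rank can map the others' memory.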
void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}
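
// CUDA graph support: return the IPC handles and offsets of the allreduce
// buffers recorded during graph capture, so they can be exchanged between
// ranks.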
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
  return fa->get_graph_buffer_ipc_meta();
}
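
// Counterpart of get_graph_buffer_ipc_meta: register the graph-capture buffers
// gathered from all ranks so that graph replays can reduce them in place.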
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
  fa->register_graph_buffers(handles, offsets);
}
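
// These functions are expected to be exposed to Python via the project's
// extension module, which is registered elsewhere in the repository. The
// snippet below is only a minimal sketch of such a registration (module name
// and doc strings are illustrative assumptions), kept out of compilation.
#if 0
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_custom_ar", &init_custom_ar, "create a CustomAllreduce handle");
  m.def("should_custom_ar", &should_custom_ar, "eligibility check");
  m.def("all_reduce_reg", &all_reduce_reg, "allreduce on a registered buffer");
  m.def("all_reduce_unreg", &all_reduce_unreg, "allreduce via staging buffer");
  m.def("dispose", &dispose, "free the CustomAllreduce handle");
  m.def("meta_size", &meta_size, "bytes needed for the signal metadata");
  m.def("register_buffer", &register_buffer, "register an IPC buffer");
  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
        "collect IPC metadata for graph-captured buffers");
  m.def("register_graph_buffers", &register_graph_buffers,
        "register graph-captured buffers from all ranks");
}
#endif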