custom_all_reduce.cu

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

#include "custom_all_reduce.cuh"

// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

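/**
 * Create a CustomAllreduce instance for this rank and return it as an opaque
 * fptr_t handle; the other entry points in this file take it back as `_fa`.
 * `meta` is reinterpreted as the aphrodite::Signal buffer (see meta_size()),
 * `handles` carries each rank's serialized cudaIpcMemHandle_t, and `offsets`
 * gives the per-rank offsets into those allocations. The world size is
 * inferred from offsets.size().
 */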
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");
  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new aphrodite::CustomAllreduce(
      reinterpret_cast<aphrodite::Signal*>(meta.data_ptr()),
      rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank,
      full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within
 * [(char*)t.data_ptr(), (char*)t.data_ptr() + t.numel() * t.element_size()).
 * This is slightly weaker than t.is_contiguous() because it allows the
 * transpose of a contiguous slice (i.e. a slice along the first dimension).
 * Currently we require this because stride information is not passed into the
 * kernels and the input tensors are treated as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

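/**
 * Dispatch the allreduce kernel on `stream` based on the output dtype. Both
 * tensors are treated as flat buffers of out.numel() elements, hence the
 * _is_weak_contiguous() requirement. The bfloat16 case is compiled only for
 * device passes with __CUDA_ARCH__ >= 800 (and for the host pass, where
 * __CUDA_ARCH__ is undefined).
 */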
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
                           reinterpret_cast<float*>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

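/**
 * Allreduce on an input tensor whose memory is already known to the
 * CustomAllreduce instance (presumably registered earlier via
 * register_buffer()), so no staging copy is needed. Runs on the current CUDA
 * stream of `inp`'s device.
 */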
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}

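/**
 * Allreduce on an unregistered input: the input bytes are first copied
 * asynchronously into `reg_buffer` (which must be registered and large enough
 * to hold them), and the reduction then runs from that staging buffer into
 * `out`, all on the current CUDA stream.
 */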
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}

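// Destroy the CustomAllreduce instance created by init_custom_ar().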
void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
  delete fa;
}

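// Number of bytes the caller must allocate for the `meta` tensor handed to
// init_custom_ar(), which is reinterpreted there as an aphrodite::Signal.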
int64_t meta_size() { return sizeof(aphrodite::Signal); }

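/**
 * Register tensor `t` with the allreduce instance, together with every rank's
 * IPC handle and offset for the same buffer, presumably so that it can later
 * be passed to all_reduce_reg() (or used as the reg_buffer of
 * all_reduce_unreg()) without a staging copy.
 */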
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

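/**
 * Pack the graph-buffer IPC metadata reported by CustomAllreduce into a CPU
 * uint8 tensor (the raw handle bytes) plus the corresponding offsets,
 * presumably so the caller can exchange them across ranks and feed them to
 * register_graph_buffers().
 */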
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
  auto options =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto handles =
      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
  return {handles, std::move(offsets)};
}

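// Register every rank's graph buffers (handles plus per-buffer offsets,
// presumably as produced by get_graph_buffer_ipc_meta() on each rank) with
// this rank's CustomAllreduce instance.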
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
  fa->register_graph_buffers(handles, offsets);
}