123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
// Opaque handle type: object pointers created in this file are handed to
// callers as plain integers and cast back to pointers on every later call.
using fptr_t = int64_t;
// The pointer <-> integer round trip is only lossless when both types have
// exactly the same width.
static_assert(sizeof(void*) == sizeof(fptr_t),
              "fptr_t must be exactly pointer-sized");
- fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
- const std::vector<std::string>& handles,
- const std::vector<int64_t>& offsets, int64_t rank,
- bool full_nvlink) {
- int world_size = offsets.size();
- if (world_size > 8)
- throw std::invalid_argument("world size > 8 is not supported");
- if (world_size % 2 != 0)
- throw std::invalid_argument("Odd num gpus is not supported for now");
- if (world_size != handles.size())
- throw std::invalid_argument(
- "handles length should equal to offsets length");
- if (rank < 0 || rank >= world_size)
- throw std::invalid_argument("invalid rank passed in");
- cudaIpcMemHandle_t ipc_handles[8];
- for (int i = 0; i < world_size; i++) {
- std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
- }
- return (fptr_t) new aphrodite::CustomAllreduce(
- reinterpret_cast<aphrodite::Signal*>(meta.data_ptr()),
- rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank,
- full_nvlink);
- }
- bool _is_weak_contiguous(torch::Tensor& t) {
- return t.is_contiguous() ||
- (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
- t.numel() * t.element_size());
- }
- bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
- bool full_nvlink) {
- auto inp_size = inp.numel() * inp.element_size();
-
- if (inp_size % 16 != 0) return false;
- if (!_is_weak_contiguous(inp)) return false;
- if (world_size == 2 || full_nvlink) return inp_size <= max_size;
-
-
- return false;
- }
- void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
- cudaStream_t stream) {
- auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
- TORCH_CHECK(_is_weak_contiguous(out));
- switch (out.scalar_type()) {
- case at::ScalarType::Float: {
- fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
- reinterpret_cast<float*>(out.data_ptr()),
- out.numel());
- break;
- }
- case at::ScalarType::Half: {
- fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
- reinterpret_cast<half*>(out.data_ptr()), out.numel());
- break;
- }
- case at::ScalarType::BFloat16: {
- fa->allreduce<nv_bfloat16>(
- stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
- reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
- break;
- }
- default:
- throw std::runtime_error(
- "custom allreduce only supports float32, float16 and bfloat16");
- }
- }
- void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
- auto stream = c10::cuda::getCurrentCUDAStream().stream();
- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
- TORCH_CHECK_EQ(inp.numel(), out.numel());
- _all_reduce(_fa, inp, out, stream);
- }
- void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
- torch::Tensor& out) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
- auto stream = c10::cuda::getCurrentCUDAStream().stream();
- auto input_size = inp.numel() * inp.element_size();
- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
- TORCH_CHECK_EQ(inp.numel(), out.numel());
- TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
- "registered buffer is too small to contain the input");
- AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
- input_size, cudaMemcpyDeviceToDevice, stream));
- _all_reduce(_fa, reg_buffer, out, stream);
- }
- void dispose(fptr_t _fa) {
- auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
- delete fa;
- }
- int64_t meta_size() { return sizeof(aphrodite::Signal); }
- void register_buffer(fptr_t _fa, torch::Tensor& t,
- const std::vector<std::string>& handles,
- const std::vector<int64_t>& offsets) {
- auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
- fa->register_buffer(handles, offsets, t.data_ptr());
- }
- std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
- fptr_t _fa) {
- auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
- auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
- auto options =
- torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
- auto handles =
- torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
- std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
- return {handles, std::move(offsets)};
- }
- void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
- const std::vector<std::vector<int64_t>>& offsets) {
- auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
- fa->register_graph_buffers(handles, offsets);
- }
|